Skip to main content

rustc_mir_transform/
check_pointers.rs

1use rustc_data_structures::thin_vec::ThinVec;
2use rustc_hir::lang_items::LangItem;
3use rustc_index::IndexVec;
4use rustc_middle::mir::visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor};
5use rustc_middle::mir::*;
6use rustc_middle::ty::{self, Ty, TyCtxt};
7use tracing::{debug, trace};
8
9/// Details of a pointer check, the condition on which we decide whether to
10/// fail the assert and an [AssertKind] that defines the behavior on failure.
11pub(crate) struct PointerCheck<'tcx> {
12    pub(crate) cond: Operand<'tcx>,
13    pub(crate) assert_kind: Box<AssertKind<Operand<'tcx>>>,
14}
15
/// When checking for borrows of field projections (`&(*ptr).a`), we might want
/// to check for the field type (type of `.a` in the example). This enum defines
/// the variations (pass the pointer [Ty] or the field [Ty]).
#[derive(Copy, Clone)]
pub(crate) enum BorrowedFieldProjectionMode {
    /// For borrows, pass the type of the borrowed place (the field type in
    /// the example above) to the check.
    FollowProjections,
    /// Always pass the pointee type of the raw pointer to the check.
    NoFollowProjections,
}
24
25/// Utility for adding a check for read/write on every sized, raw pointer.
26///
27/// Visits every read/write access to a [Sized], raw pointer and inserts a
28/// new basic block directly before the pointer access. (Read/write accesses
29/// are determined by the `PlaceContext` of the MIR visitor.) Then calls
30/// `on_finding` to insert the actual logic for a pointer check (e.g. check for
31/// alignment). A check can choose to follow borrows of field projections via
32/// the `field_projection_mode` parameter.
33///
34/// This utility takes care of the right order of blocks, the only thing a
35/// caller must do in `on_finding` is:
36/// - Append [Statement]s to `stmts`.
37/// - Append [LocalDecl]s to `local_decls`.
38/// - Return a [PointerCheck] that contains the condition and an [AssertKind].
39///   The AssertKind must be a panic with `#[rustc_nounwind]`. The condition
40///   should always return the boolean `is_ok`, so evaluate to true in case of
41///   success and fail the check otherwise.
42/// This utility will insert a terminator block that asserts on the condition
43/// and panics on failure.
44pub(crate) fn check_pointers<'tcx, F>(
45    tcx: TyCtxt<'tcx>,
46    body: &mut Body<'tcx>,
47    excluded_pointees: &[Ty<'tcx>],
48    on_finding: F,
49    field_projection_mode: BorrowedFieldProjectionMode,
50) where
51    F: Fn(
52        /* tcx: */ TyCtxt<'tcx>,
53        /* pointer: */ Place<'tcx>,
54        /* pointee_ty: */ Ty<'tcx>,
55        /* context: */ PlaceContext,
56        /* local_decls: */ &mut IndexVec<Local, LocalDecl<'tcx>>,
57        /* stmts: */ &mut Vec<Statement<'tcx>>,
58        /* source_info: */ SourceInfo,
59    ) -> PointerCheck<'tcx>,
60{
61    // This pass emits new panics. If for whatever reason we do not have a panic
62    // implementation, running this pass may cause otherwise-valid code to not compile.
63    if tcx.lang_items().get(LangItem::PanicImpl).is_none() {
64        return;
65    }
66
67    let typing_env = body.typing_env(tcx);
68    let basic_blocks = body.basic_blocks.as_mut();
69    let local_decls = &mut body.local_decls;
70
71    // This operation inserts new blocks. Each insertion changes the Location for all
72    // statements/blocks after. Iterating or visiting the MIR in order would require updating
73    // our current location after every insertion. By iterating backwards, we dodge this issue:
74    // The only Locations that an insertion changes have already been handled.
75    for block in basic_blocks.indices().rev() {
76        for statement_index in (0..basic_blocks[block].statements.len()).rev() {
77            let location = Location { block, statement_index };
78            let statement = &basic_blocks[block].statements[statement_index];
79            let source_info = statement.source_info;
80
81            let mut finder = PointerFinder::new(
82                tcx,
83                local_decls,
84                typing_env,
85                excluded_pointees,
86                field_projection_mode,
87            );
88            finder.visit_statement(statement, location);
89
90            for (local, ty, context) in finder.into_found_pointers() {
91                {
    use ::tracing::__macro_support::Callsite as _;
    static __CALLSITE: ::tracing::callsite::DefaultCallsite =
        {
            static META: ::tracing::Metadata<'static> =
                {
                    ::tracing_core::metadata::Metadata::new("event compiler/rustc_mir_transform/src/check_pointers.rs:91",
                        "rustc_mir_transform::check_pointers",
                        ::tracing::Level::DEBUG,
                        ::tracing_core::__macro_support::Option::Some("compiler/rustc_mir_transform/src/check_pointers.rs"),
                        ::tracing_core::__macro_support::Option::Some(91u32),
                        ::tracing_core::__macro_support::Option::Some("rustc_mir_transform::check_pointers"),
                        ::tracing_core::field::FieldSet::new(&["message"],
                            ::tracing_core::callsite::Identifier(&__CALLSITE)),
                        ::tracing::metadata::Kind::EVENT)
                };
            ::tracing::callsite::DefaultCallsite::new(&META)
        };
    let enabled =
        ::tracing::Level::DEBUG <= ::tracing::level_filters::STATIC_MAX_LEVEL
                &&
                ::tracing::Level::DEBUG <=
                    ::tracing::level_filters::LevelFilter::current() &&
            {
                let interest = __CALLSITE.interest();
                !interest.is_never() &&
                    ::tracing::__macro_support::__is_enabled(__CALLSITE.metadata(),
                        interest)
            };
    if enabled {
        (|value_set: ::tracing::field::ValueSet|
                    {
                        let meta = __CALLSITE.metadata();
                        ::tracing::Event::dispatch(meta, &value_set);
                        ;
                    })({
                #[allow(unused_imports)]
                use ::tracing::field::{debug, display, Value};
                let mut iter = __CALLSITE.metadata().fields().iter();
                __CALLSITE.metadata().fields().value_set(&[(&::tracing::__macro_support::Iterator::next(&mut iter).expect("FieldSet corrupted (this is a bug)"),
                                    ::tracing::__macro_support::Option::Some(&format_args!("Inserting check for {0:?}",
                                                    ty) as &dyn Value))])
            });
    } else { ; }
};debug!("Inserting check for {:?}", ty);
92                let new_block = split_block(basic_blocks, location);
93
94                // Invoke `on_finding` which appends to `local_decls` and the
95                // blocks statements. It returns information about the assert
96                // we're performing in the Terminator.
97                let block_data = &mut basic_blocks[block];
98                let pointer_check = on_finding(
99                    tcx,
100                    local,
101                    ty,
102                    context,
103                    local_decls,
104                    &mut block_data.statements,
105                    source_info,
106                );
107                block_data.terminator = Some(Terminator {
108                    source_info,
109                    kind: TerminatorKind::Assert {
110                        cond: pointer_check.cond,
111                        expected: true,
112                        target: new_block,
113                        msg: pointer_check.assert_kind,
114                        // This calls a panic function associated with the pointer check, which
115                        // is #[rustc_nounwind]. We never want to insert an unwind into unsafe
116                        // code, because unwinding could make a failing UB check turn into much
117                        // worse UB when we start unwinding.
118                        unwind: UnwindAction::Unreachable,
119                    },
120                    attributes: ThinVec::new(),
121                });
122            }
123        }
124    }
125}
126
127struct PointerFinder<'a, 'tcx> {
128    tcx: TyCtxt<'tcx>,
129    local_decls: &'a mut LocalDecls<'tcx>,
130    typing_env: ty::TypingEnv<'tcx>,
131    pointers: Vec<(Place<'tcx>, Ty<'tcx>, PlaceContext)>,
132    excluded_pointees: &'a [Ty<'tcx>],
133    field_projection_mode: BorrowedFieldProjectionMode,
134}
135
136impl<'a, 'tcx> PointerFinder<'a, 'tcx> {
137    fn new(
138        tcx: TyCtxt<'tcx>,
139        local_decls: &'a mut LocalDecls<'tcx>,
140        typing_env: ty::TypingEnv<'tcx>,
141        excluded_pointees: &'a [Ty<'tcx>],
142        field_projection_mode: BorrowedFieldProjectionMode,
143    ) -> Self {
144        PointerFinder {
145            tcx,
146            local_decls,
147            typing_env,
148            excluded_pointees,
149            pointers: Vec::new(),
150            field_projection_mode,
151        }
152    }
153
154    fn into_found_pointers(self) -> Vec<(Place<'tcx>, Ty<'tcx>, PlaceContext)> {
155        self.pointers
156    }
157
158    /// Whether or not we should visit a [Place] with [PlaceContext].
159    ///
160    /// We generally only visit Reads/Writes to a place and only Borrows if
161    /// requested.
162    fn should_visit_place(&self, context: PlaceContext) -> bool {
163        match context {
164            PlaceContext::MutatingUse(
165                MutatingUseContext::Store
166                | MutatingUseContext::Call
167                | MutatingUseContext::Yield
168                | MutatingUseContext::Drop
169                | MutatingUseContext::Borrow,
170            ) => true,
171            PlaceContext::NonMutatingUse(
172                NonMutatingUseContext::Copy
173                | NonMutatingUseContext::Move
174                | NonMutatingUseContext::SharedBorrow,
175            ) => true,
176            _ => false,
177        }
178    }
179}
180
181impl<'a, 'tcx> Visitor<'tcx> for PointerFinder<'a, 'tcx> {
182    fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
183        if !self.should_visit_place(context) || !place.is_indirect() {
184            return;
185        }
186
187        // Get the place and type we visit.
188        let pointer = Place::from(place.local);
189        let pointer_ty = pointer.ty(self.local_decls, self.tcx).ty;
190
191        // We only want to check places based on raw pointers
192        let &ty::RawPtr(mut pointee_ty, _) = pointer_ty.kind() else {
193            {
    use ::tracing::__macro_support::Callsite as _;
    static __CALLSITE: ::tracing::callsite::DefaultCallsite =
        {
            static META: ::tracing::Metadata<'static> =
                {
                    ::tracing_core::metadata::Metadata::new("event compiler/rustc_mir_transform/src/check_pointers.rs:193",
                        "rustc_mir_transform::check_pointers",
                        ::tracing::Level::TRACE,
                        ::tracing_core::__macro_support::Option::Some("compiler/rustc_mir_transform/src/check_pointers.rs"),
                        ::tracing_core::__macro_support::Option::Some(193u32),
                        ::tracing_core::__macro_support::Option::Some("rustc_mir_transform::check_pointers"),
                        ::tracing_core::field::FieldSet::new(&["message"],
                            ::tracing_core::callsite::Identifier(&__CALLSITE)),
                        ::tracing::metadata::Kind::EVENT)
                };
            ::tracing::callsite::DefaultCallsite::new(&META)
        };
    let enabled =
        ::tracing::Level::TRACE <= ::tracing::level_filters::STATIC_MAX_LEVEL
                &&
                ::tracing::Level::TRACE <=
                    ::tracing::level_filters::LevelFilter::current() &&
            {
                let interest = __CALLSITE.interest();
                !interest.is_never() &&
                    ::tracing::__macro_support::__is_enabled(__CALLSITE.metadata(),
                        interest)
            };
    if enabled {
        (|value_set: ::tracing::field::ValueSet|
                    {
                        let meta = __CALLSITE.metadata();
                        ::tracing::Event::dispatch(meta, &value_set);
                        ;
                    })({
                #[allow(unused_imports)]
                use ::tracing::field::{debug, display, Value};
                let mut iter = __CALLSITE.metadata().fields().iter();
                __CALLSITE.metadata().fields().value_set(&[(&::tracing::__macro_support::Iterator::next(&mut iter).expect("FieldSet corrupted (this is a bug)"),
                                    ::tracing::__macro_support::Option::Some(&format_args!("Indirect, but not based on an raw ptr, not checking {0:?}",
                                                    place) as &dyn Value))])
            });
    } else { ; }
};trace!("Indirect, but not based on an raw ptr, not checking {:?}", place);
194            return;
195        };
196
197        // If we see a borrow of a field projection, we want to pass the field type to the
198        // check and not the pointee type.
199        if #[allow(non_exhaustive_omitted_patterns)] match self.field_projection_mode {
    BorrowedFieldProjectionMode::FollowProjections => true,
    _ => false,
}matches!(self.field_projection_mode, BorrowedFieldProjectionMode::FollowProjections)
200            && #[allow(non_exhaustive_omitted_patterns)] match context {
    PlaceContext::NonMutatingUse(NonMutatingUseContext::SharedBorrow) |
        PlaceContext::MutatingUse(MutatingUseContext::Borrow) => true,
    _ => false,
}matches!(
201                context,
202                PlaceContext::NonMutatingUse(NonMutatingUseContext::SharedBorrow)
203                    | PlaceContext::MutatingUse(MutatingUseContext::Borrow)
204            )
205        {
206            // Naturally, the field type is type of the initial place we look at.
207            pointee_ty = place.ty(self.local_decls, self.tcx).ty;
208        }
209
210        // Ideally we'd support this in the future, but for now we are limited to sized types.
211        if !pointee_ty.is_sized(self.tcx, self.typing_env) {
212            {
    use ::tracing::__macro_support::Callsite as _;
    static __CALLSITE: ::tracing::callsite::DefaultCallsite =
        {
            static META: ::tracing::Metadata<'static> =
                {
                    ::tracing_core::metadata::Metadata::new("event compiler/rustc_mir_transform/src/check_pointers.rs:212",
                        "rustc_mir_transform::check_pointers",
                        ::tracing::Level::TRACE,
                        ::tracing_core::__macro_support::Option::Some("compiler/rustc_mir_transform/src/check_pointers.rs"),
                        ::tracing_core::__macro_support::Option::Some(212u32),
                        ::tracing_core::__macro_support::Option::Some("rustc_mir_transform::check_pointers"),
                        ::tracing_core::field::FieldSet::new(&["message"],
                            ::tracing_core::callsite::Identifier(&__CALLSITE)),
                        ::tracing::metadata::Kind::EVENT)
                };
            ::tracing::callsite::DefaultCallsite::new(&META)
        };
    let enabled =
        ::tracing::Level::TRACE <= ::tracing::level_filters::STATIC_MAX_LEVEL
                &&
                ::tracing::Level::TRACE <=
                    ::tracing::level_filters::LevelFilter::current() &&
            {
                let interest = __CALLSITE.interest();
                !interest.is_never() &&
                    ::tracing::__macro_support::__is_enabled(__CALLSITE.metadata(),
                        interest)
            };
    if enabled {
        (|value_set: ::tracing::field::ValueSet|
                    {
                        let meta = __CALLSITE.metadata();
                        ::tracing::Event::dispatch(meta, &value_set);
                        ;
                    })({
                #[allow(unused_imports)]
                use ::tracing::field::{debug, display, Value};
                let mut iter = __CALLSITE.metadata().fields().iter();
                __CALLSITE.metadata().fields().value_set(&[(&::tracing::__macro_support::Iterator::next(&mut iter).expect("FieldSet corrupted (this is a bug)"),
                                    ::tracing::__macro_support::Option::Some(&format_args!("Raw pointer, but pointee is not known to be sized: {0:?}",
                                                    pointer_ty) as &dyn Value))])
            });
    } else { ; }
};trace!("Raw pointer, but pointee is not known to be sized: {:?}", pointer_ty);
213            return;
214        }
215
216        // We don't need to look for slices, we already rejected unsized types above.
217        let element_ty = match pointee_ty.kind() {
218            ty::Array(ty, _) => *ty,
219            _ => pointee_ty,
220        };
221        // Check if we excluded this pointee type from the check.
222        if self.excluded_pointees.contains(&element_ty) {
223            {
    use ::tracing::__macro_support::Callsite as _;
    static __CALLSITE: ::tracing::callsite::DefaultCallsite =
        {
            static META: ::tracing::Metadata<'static> =
                {
                    ::tracing_core::metadata::Metadata::new("event compiler/rustc_mir_transform/src/check_pointers.rs:223",
                        "rustc_mir_transform::check_pointers",
                        ::tracing::Level::TRACE,
                        ::tracing_core::__macro_support::Option::Some("compiler/rustc_mir_transform/src/check_pointers.rs"),
                        ::tracing_core::__macro_support::Option::Some(223u32),
                        ::tracing_core::__macro_support::Option::Some("rustc_mir_transform::check_pointers"),
                        ::tracing_core::field::FieldSet::new(&["message"],
                            ::tracing_core::callsite::Identifier(&__CALLSITE)),
                        ::tracing::metadata::Kind::EVENT)
                };
            ::tracing::callsite::DefaultCallsite::new(&META)
        };
    let enabled =
        ::tracing::Level::TRACE <= ::tracing::level_filters::STATIC_MAX_LEVEL
                &&
                ::tracing::Level::TRACE <=
                    ::tracing::level_filters::LevelFilter::current() &&
            {
                let interest = __CALLSITE.interest();
                !interest.is_never() &&
                    ::tracing::__macro_support::__is_enabled(__CALLSITE.metadata(),
                        interest)
            };
    if enabled {
        (|value_set: ::tracing::field::ValueSet|
                    {
                        let meta = __CALLSITE.metadata();
                        ::tracing::Event::dispatch(meta, &value_set);
                        ;
                    })({
                #[allow(unused_imports)]
                use ::tracing::field::{debug, display, Value};
                let mut iter = __CALLSITE.metadata().fields().iter();
                __CALLSITE.metadata().fields().value_set(&[(&::tracing::__macro_support::Iterator::next(&mut iter).expect("FieldSet corrupted (this is a bug)"),
                                    ::tracing::__macro_support::Option::Some(&format_args!("Skipping pointer for type: {0:?}",
                                                    pointee_ty) as &dyn Value))])
            });
    } else { ; }
};trace!("Skipping pointer for type: {:?}", pointee_ty);
224            return;
225        }
226
227        self.pointers.push((pointer, pointee_ty, context));
228
229        self.super_place(place, context, location);
230    }
231}
232
233fn split_block(
234    basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>,
235    location: Location,
236) -> BasicBlock {
237    let block_data = &mut basic_blocks[location.block];
238
239    // Drain every statement after this one and move the current terminator to a new basic block.
240    let new_block = BasicBlockData::new_stmts(
241        block_data.statements.split_off(location.statement_index),
242        block_data.terminator.take(),
243        block_data.is_cleanup,
244    );
245
246    basic_blocks.push(new_block)
247}