1use rustc_hir::lang_items::LangItem;
2use rustc_index::IndexVec;
3use rustc_middle::mir::visit::{MutatingUseContext, NonMutatingUseContext, PlaceContext, Visitor};
4use rustc_middle::mir::*;
5use rustc_middle::ty::{self, Ty, TyCtxt};
6use tracing::{debug, trace};
7
/// Result of an `on_finding` callback passed to [`check_pointers`].
///
/// It bundles the condition that the inserted `Assert` terminator evaluates with the
/// message that terminator reports when the condition does not hold.
pub(crate) struct PointerCheck<'tcx> {
    /// Condition checked by the inserted assert. The assert is built with `expected: true`,
    /// so this must evaluate to `true` when the pointer access is acceptable.
    pub(crate) cond: Operand<'tcx>,
    /// Installed as the assert's `msg`; describes the failure when `cond` is `false`.
    pub(crate) assert_kind: Box<AssertKind<Operand<'tcx>>>,
}
14
/// Selects whether [`check_pointers`] also inserts checks for places that are *borrowed*
/// through a raw pointer, i.e. places used in a [`MutatingUseContext::Borrow`] or
/// [`NonMutatingUseContext::SharedBorrow`] context.
///
/// `Debug` is derived so the mode can be logged; `PartialEq`/`Eq` allow comparing modes
/// with `==` instead of pattern matching.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub(crate) enum BorrowCheckMode {
    /// Borrows of a raw-pointer dereference are checked like the other visited accesses.
    IncludeBorrows,
    /// Borrows of a raw-pointer dereference are skipped.
    ExcludeBorrows,
}
23
/// Inserts a runtime check in front of every statement that accesses memory through a
/// raw pointer to a [`Sized`] pointee.
///
/// Each statement of `body` is inspected with a `PointerFinder`. For every raw-pointer
/// access it reports, the containing block is split right before the statement, and
/// `on_finding` is called to emit the check itself. The block that now ends just before
/// the statement is given an `Assert` terminator that continues into the split-off block
/// when the check passes.
///
/// Parameters:
/// - `excluded_pointees`: no check is inserted when the pointee type (or, for an array
///   pointee, its element type) is contained in this slice.
/// - `on_finding`: builds one check. It receives the pointer place, the pointee type, the
///   access context, the body's local declarations (it may add new locals), the statement
///   list of the block preceding the access (it may append statements), and the source
///   info of the accessing statement. It returns a [`PointerCheck`] whose `cond` must be
///   `true` for an acceptable access.
/// - `borrow_check_mode`: whether borrows of a raw-pointer dereference are checked too.
///
/// The function does nothing if the crate has no `PanicImpl` lang item.
pub(crate) fn check_pointers<'tcx, F>(
    tcx: TyCtxt<'tcx>,
    body: &mut Body<'tcx>,
    excluded_pointees: &[Ty<'tcx>],
    on_finding: F,
    borrow_check_mode: BorrowCheckMode,
) where
    F: Fn(
        /* tcx: */ TyCtxt<'tcx>,
        /* pointer: */ Place<'tcx>,
        /* pointee_ty: */ Ty<'tcx>,
        /* context: */ PlaceContext,
        /* local_decls: */ &mut IndexVec<Local, LocalDecl<'tcx>>,
        /* stmts: */ &mut Vec<Statement<'tcx>>,
        /* source_info: */ SourceInfo,
    ) -> PointerCheck<'tcx>,
{
    // The inserted asserts fail by panicking; presumably emitting them without a panic
    // implementation would break otherwise-valid code, so bail out instead.
    if tcx.lang_items().get(LangItem::PanicImpl).is_none() {
        return;
    }

    let typing_env = body.typing_env(tcx);
    let basic_blocks = body.basic_blocks.as_mut();
    let local_decls = &mut body.local_decls;

    // Walk statements back to front. A split only moves the statements at or after the
    // current index into a new block, so the Locations of statements still to be visited
    // stay valid. The block range is fixed before any block is appended, so blocks
    // created by splitting are never revisited.
    for block in (0..basic_blocks.len()).rev() {
        let block = block.into();
        for statement_index in (0..basic_blocks[block].statements.len()).rev() {
            let location = Location { block, statement_index };
            let statement = &basic_blocks[block].statements[statement_index];
            let source_info = statement.source_info;

            // Collect every raw-pointer access made by this one statement.
            let mut finder = PointerFinder::new(
                tcx,
                local_decls,
                typing_env,
                excluded_pointees,
                borrow_check_mode,
            );
            finder.visit_statement(statement, location);

            for (local, ty, context) in finder.into_found_pointers() {
                debug!("Inserting check for {:?}", ty);
                // Move the statement (and everything after it, including any check
                // inserted for an earlier finding of this statement) into a new block.
                // Later findings therefore end up placed before earlier ones, and all of
                // them run before the statement.
                let new_block = split_block(basic_blocks, location);

                // Let the caller append the check's statements/locals to the truncated
                // block, then terminate that block with the assert on the returned condition.
                let block_data = &mut basic_blocks[block];
                let pointer_check = on_finding(
                    tcx,
                    local,
                    ty,
                    context,
                    local_decls,
                    &mut block_data.statements,
                    source_info,
                );
                block_data.terminator = Some(Terminator {
                    source_info,
                    kind: TerminatorKind::Assert {
                        cond: pointer_check.cond,
                        expected: true,
                        target: new_block,
                        msg: pointer_check.assert_kind,
                        // The assert is declared as never unwinding. NOTE(review): this
                        // relies on the failure path chosen via `assert_kind` not unwinding;
                        // `on_finding` implementations must uphold that — confirm per caller.
                        unwind: UnwindAction::Unreachable,
                    },
                });
            }
        }
    }
}
125
/// MIR visitor state that records raw-pointer accesses; `check_pointers` creates a fresh
/// one for each statement it inspects.
struct PointerFinder<'a, 'tcx> {
    tcx: TyCtxt<'tcx>,
    /// Local declarations of the body, used to look up the type of a place's base local.
    // NOTE(review): only read in this file's visible code; a shared reference may
    // suffice — confirm before changing the signature.
    local_decls: &'a mut LocalDecls<'tcx>,
    /// Typing environment used for the `is_sized` query on pointee types.
    typing_env: ty::TypingEnv<'tcx>,
    /// Findings so far: `(pointer place, pointee type, access context)`.
    pointers: Vec<(Place<'tcx>, Ty<'tcx>, PlaceContext)>,
    /// Pointee (or array element) types that never get a check.
    excluded_pointees: &'a [Ty<'tcx>],
    /// Whether borrow contexts count as accesses to check.
    borrow_check_mode: BorrowCheckMode,
}
134
135impl<'a, 'tcx> PointerFinder<'a, 'tcx> {
136 fn new(
137 tcx: TyCtxt<'tcx>,
138 local_decls: &'a mut LocalDecls<'tcx>,
139 typing_env: ty::TypingEnv<'tcx>,
140 excluded_pointees: &'a [Ty<'tcx>],
141 borrow_check_mode: BorrowCheckMode,
142 ) -> Self {
143 PointerFinder {
144 tcx,
145 local_decls,
146 typing_env,
147 excluded_pointees,
148 pointers: Vec::new(),
149 borrow_check_mode,
150 }
151 }
152
153 fn into_found_pointers(self) -> Vec<(Place<'tcx>, Ty<'tcx>, PlaceContext)> {
154 self.pointers
155 }
156
157 fn should_visit_place(&self, context: PlaceContext) -> bool {
162 match context {
163 PlaceContext::MutatingUse(
164 MutatingUseContext::Store
165 | MutatingUseContext::Call
166 | MutatingUseContext::Yield
167 | MutatingUseContext::Drop,
168 ) => true,
169 PlaceContext::NonMutatingUse(
170 NonMutatingUseContext::Copy | NonMutatingUseContext::Move,
171 ) => true,
172 PlaceContext::MutatingUse(MutatingUseContext::Borrow)
173 | PlaceContext::NonMutatingUse(NonMutatingUseContext::SharedBorrow) => {
174 matches!(self.borrow_check_mode, BorrowCheckMode::IncludeBorrows)
175 }
176 _ => false,
177 }
178 }
179}
180
181impl<'a, 'tcx> Visitor<'tcx> for PointerFinder<'a, 'tcx> {
182 fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
183 if !self.should_visit_place(context) || !place.is_indirect() {
184 return;
185 }
186
187 let pointer = Place::from(place.local);
190 let pointer_ty = self.local_decls[place.local].ty;
191
192 if !pointer_ty.is_raw_ptr() {
194 trace!("Indirect, but not based on an raw ptr, not checking {:?}", place);
195 return;
196 }
197
198 let pointee_ty =
199 pointer_ty.builtin_deref(true).expect("no builtin_deref for an raw pointer");
200 if !pointee_ty.is_sized(self.tcx, self.typing_env) {
202 trace!("Raw pointer, but pointee is not known to be sized: {:?}", pointer_ty);
203 return;
204 }
205
206 let element_ty = match pointee_ty.kind() {
208 ty::Array(ty, _) => *ty,
209 _ => pointee_ty,
210 };
211 if self.excluded_pointees.contains(&element_ty) {
212 trace!("Skipping pointer for type: {:?}", pointee_ty);
213 return;
214 }
215
216 self.pointers.push((pointer, pointee_ty, context));
217
218 self.super_place(place, context, location);
219 }
220}
221
222fn split_block(
223 basic_blocks: &mut IndexVec<BasicBlock, BasicBlockData<'_>>,
224 location: Location,
225) -> BasicBlock {
226 let block_data = &mut basic_blocks[location.block];
227
228 let new_block = BasicBlockData {
230 statements: block_data.statements.split_off(location.statement_index),
231 terminator: block_data.terminator.take(),
232 is_cleanup: block_data.is_cleanup,
233 };
234
235 basic_blocks.push(new_block)
236}