1use std::sync::Arc;
2
3use thin_vec::thin_vec;
4
5use crate::LoweringContext;
6
7impl<'a, 'hir> LoweringContext<'a, 'hir> {
8 pub(super) fn lower_contract(
19 &mut self,
20 body: impl FnOnce(&mut Self) -> rustc_hir::Expr<'hir>,
21 contract: &rustc_ast::FnContract,
22 ) -> rustc_hir::Expr<'hir> {
23 let contract_decls = self.lower_decls(contract);
27
28 match (&contract.requires, &contract.ensures) {
29 (Some(req), Some(ens)) => {
30 let precond = self.lower_precond(req);
54 let postcond_checker = self.lower_postcond_checker(ens);
55
56 let contract_check = self.lower_contract_check_with_postcond(
57 contract_decls,
58 Some(precond),
59 postcond_checker,
60 );
61
62 let wrapped_body =
63 self.wrap_body_with_contract_check(body, contract_check, postcond_checker.span);
64 self.expr_block(wrapped_body)
65 }
66 (None, Some(ens)) => {
67 let postcond_checker = self.lower_postcond_checker(ens);
89 let contract_check =
90 self.lower_contract_check_with_postcond(contract_decls, None, postcond_checker);
91
92 let wrapped_body =
93 self.wrap_body_with_contract_check(body, contract_check, postcond_checker.span);
94 self.expr_block(wrapped_body)
95 }
96 (Some(req), None) => {
97 let precond = self.lower_precond(req);
111 let precond_check = self.lower_contract_check_just_precond(contract_decls, precond);
112
113 let body = self.arena.alloc(body(self));
114
115 let wrapped_body = self.block_all(
117 body.span,
118 self.arena.alloc_from_iter([precond_check].into_iter()),
119 Some(body),
120 );
121 self.expr_block(wrapped_body)
122 }
123 (None, None) => body(self),
124 }
125 }
126
127 fn lower_decls(&mut self, contract: &rustc_ast::FnContract) -> &'hir [rustc_hir::Stmt<'hir>] {
128 let (decls, decls_tail) = self.lower_stmts(&contract.declarations);
129
130 if let Some(e) = decls_tail {
131 let tail = self.stmt_expr(e.span, *e);
133 self.arena.alloc_from_iter(decls.into_iter().map(|d| *d).chain([tail].into_iter()))
134 } else {
135 decls
136 }
137 }
138
139 fn lower_precond(&mut self, req: &Box<rustc_ast::Expr>) -> rustc_hir::Stmt<'hir> {
141 let lowered_req = self.lower_expr_mut(&req);
142 let req_span = self.mark_span_with_reason(
143 rustc_span::DesugaringKind::Contract,
144 lowered_req.span,
145 Some(Arc::clone(&self.allow_contracts)),
146 );
147 let precond = self.expr_call_lang_item_fn_mut(
148 req_span,
149 rustc_hir::LangItem::ContractCheckRequires,
150 &*arena_vec![self; lowered_req],
151 );
152 self.stmt_expr(req.span, precond)
153 }
154
155 fn lower_postcond_checker(
156 &mut self,
157 ens: &Box<rustc_ast::Expr>,
158 ) -> &'hir rustc_hir::Expr<'hir> {
159 let ens_span = self.lower_span(ens.span);
160 let ens_span = self.mark_span_with_reason(
161 rustc_span::DesugaringKind::Contract,
162 ens_span,
163 Some(Arc::clone(&self.allow_contracts)),
164 );
165 let lowered_ens = self.lower_expr_mut(&ens);
166 self.expr_call_lang_item_fn(
167 ens_span,
168 rustc_hir::LangItem::ContractBuildCheckEnsures,
169 &*arena_vec![self; lowered_ens],
170 )
171 }
172
173 fn lower_contract_check_just_precond(
174 &mut self,
175 contract_decls: &'hir [rustc_hir::Stmt<'hir>],
176 precond: rustc_hir::Stmt<'hir>,
177 ) -> rustc_hir::Stmt<'hir> {
178 let stmts = self
179 .arena
180 .alloc_from_iter(contract_decls.into_iter().map(|d| *d).chain([precond].into_iter()));
181
182 let then_block_stmts = self.block_all(precond.span, stmts, None);
183 let then_block = self.arena.alloc(self.expr_block(&then_block_stmts));
184
185 let precond_check = rustc_hir::ExprKind::If(
186 self.arena.alloc(self.expr_bool_literal(precond.span, self.tcx.sess.contract_checks())),
187 then_block,
188 None,
189 );
190
191 let precond_check = self.expr(precond.span, precond_check);
192 self.stmt_expr(precond.span, precond_check)
193 }
194
195 fn lower_contract_check_with_postcond(
196 &mut self,
197 contract_decls: &'hir [rustc_hir::Stmt<'hir>],
198 precond: Option<rustc_hir::Stmt<'hir>>,
199 postcond_checker: &'hir rustc_hir::Expr<'hir>,
200 ) -> &'hir rustc_hir::Expr<'hir> {
201 let stmts = self
202 .arena
203 .alloc_from_iter(contract_decls.into_iter().map(|d| *d).chain(precond.into_iter()));
204 let span = match precond {
205 Some(precond) => precond.span,
206 None => postcond_checker.span,
207 };
208
209 let postcond_checker = self.arena.alloc(self.expr_enum_variant_lang_item(
210 postcond_checker.span,
211 rustc_hir::lang_items::LangItem::OptionSome,
212 &*arena_vec![self; *postcond_checker],
213 ));
214 let then_block_stmts = self.block_all(span, stmts, Some(postcond_checker));
215 let then_block = self.arena.alloc(self.expr_block(&then_block_stmts));
216
217 let none_expr = self.arena.alloc(self.expr_enum_variant_lang_item(
218 postcond_checker.span,
219 rustc_hir::lang_items::LangItem::OptionNone,
220 Default::default(),
221 ));
222 let else_block = self.block_expr(none_expr);
223 let else_block = self.arena.alloc(self.expr_block(else_block));
224
225 let contract_check = rustc_hir::ExprKind::If(
226 self.arena.alloc(self.expr_bool_literal(span, self.tcx.sess.contract_checks())),
227 then_block,
228 Some(else_block),
229 );
230 self.arena.alloc(self.expr(span, contract_check))
231 }
232
233 fn wrap_body_with_contract_check(
234 &mut self,
235 body: impl FnOnce(&mut Self) -> rustc_hir::Expr<'hir>,
236 contract_check: &'hir rustc_hir::Expr<'hir>,
237 postcond_span: rustc_span::Span,
238 ) -> &'hir rustc_hir::Block<'hir> {
239 let check_ident: rustc_span::Ident =
240 rustc_span::Ident::from_str_and_span("__ensures_checker", postcond_span);
241 let (check_hir_id, postcond_decl) = {
242 let (checker_pat, check_hir_id) = self.pat_ident_binding_mode_mut(
244 postcond_span,
245 check_ident,
246 rustc_hir::BindingMode::NONE,
247 );
248 (
249 check_hir_id,
250 self.stmt_let_pat(
251 None,
252 postcond_span,
253 Some(contract_check),
254 self.arena.alloc(checker_pat),
255 rustc_hir::LocalSource::Contract,
256 ),
257 )
258 };
259
260 self.contract_ensures = Some((postcond_span, check_ident, check_hir_id));
263 let body = self.arena.alloc(body(self));
264
265 let body = self.inject_ensures_check(body, postcond_span, check_ident, check_hir_id);
267
268 let wrapped_body = self.block_all(
270 body.span,
271 self.arena.alloc_from_iter([postcond_decl].into_iter()),
272 Some(body),
273 );
274 wrapped_body
275 }
276
277 pub(super) fn checked_return(
280 &mut self,
281 opt_expr: Option<&'hir rustc_hir::Expr<'hir>>,
282 ) -> rustc_hir::ExprKind<'hir> {
283 let checked_ret =
284 if let Some((check_span, check_ident, check_hir_id)) = self.contract_ensures {
285 let expr = opt_expr.unwrap_or_else(|| self.expr_unit(check_span));
286 Some(self.inject_ensures_check(expr, check_span, check_ident, check_hir_id))
287 } else {
288 opt_expr
289 };
290 rustc_hir::ExprKind::Ret(checked_ret)
291 }
292
293 pub(super) fn inject_ensures_check(
295 &mut self,
296 expr: &'hir rustc_hir::Expr<'hir>,
297 span: rustc_span::Span,
298 cond_ident: rustc_span::Ident,
299 cond_hir_id: rustc_hir::HirId,
300 ) -> &'hir rustc_hir::Expr<'hir> {
301 let ret_ident: rustc_span::Ident = rustc_span::Ident::from_str_and_span("__ret", span);
311
312 let (ret_pat, ret_hir_id) =
314 self.pat_ident_binding_mode_mut(span, ret_ident, rustc_hir::BindingMode::NONE);
315
316 let ret_stmt = self.stmt_let_pat(
317 None,
318 span,
319 Some(expr),
320 self.arena.alloc(ret_pat),
321 rustc_hir::LocalSource::Contract,
322 );
323
324 let ret = self.expr_ident(span, ret_ident, ret_hir_id);
325
326 let cond_fn = self.expr_ident(span, cond_ident, cond_hir_id);
327 let contract_check = self.expr_call_lang_item_fn_mut(
328 span,
329 rustc_hir::LangItem::ContractCheckEnsures,
330 arena_vec![self; *cond_fn, *ret],
331 );
332 let contract_check = self.arena.alloc(contract_check);
333 let call_expr = self.block_expr_block(contract_check);
334
335 let ret = self.expr_ident(span, ret_ident, ret_hir_id);
338 let ret_block = self.block_expr_block(ret);
339
340 let contracts_enabled: rustc_hir::Expr<'_> =
341 self.expr_bool_literal(span, self.tcx.sess.contract_checks());
342 let contract_check = self.arena.alloc(self.expr(
343 span,
344 rustc_hir::ExprKind::If(
345 self.arena.alloc(contracts_enabled),
346 call_expr,
347 Some(ret_block),
348 ),
349 ));
350
351 let attrs: rustc_ast::AttrVec = thin_vec![self.unreachable_code_attr(span)];
352 self.lower_attrs(contract_check.hir_id, &attrs, span, rustc_hir::Target::Expression);
353
354 let ret_block = self.block_all(span, arena_vec![self; ret_stmt], Some(contract_check));
355 self.arena.alloc(self.expr_block(self.arena.alloc(ret_block)))
356 }
357}