rustc_ast_lowering/
contract.rs

1use std::sync::Arc;
2
3use thin_vec::thin_vec;
4
5use crate::LoweringContext;
6
7impl<'a, 'hir> LoweringContext<'a, 'hir> {
8    /// Lowered contracts are guarded with the `contract_checks` compiler flag,
9    /// i.e. the flag turns into a boolean guard in the lowered HIR. The reason
10    /// for not eliminating the contract code entirely when the `contract_checks`
11    /// flag is disabled is so that contracts can be type checked, even when
12    /// they are disabled, which avoids them becoming stale (i.e. out of sync
13    /// with the codebase) over time.
14    ///
15    /// The optimiser should be able to eliminate all contract code guarded
16    /// by `if false`, leaving the original body intact when runtime contract
17    /// checks are disabled.
18    pub(super) fn lower_contract(
19        &mut self,
20        body: impl FnOnce(&mut Self) -> rustc_hir::Expr<'hir>,
21        contract: &rustc_ast::FnContract,
22    ) -> rustc_hir::Expr<'hir> {
23        // The order in which things are lowered is important! I.e to
24        // refer to variables in contract_decls from postcond/precond,
25        // we must lower it first!
26        let contract_decls = self.lower_decls(contract);
27
28        match (&contract.requires, &contract.ensures) {
29            (Some(req), Some(ens)) => {
30                // Lower the fn contract, which turns:
31                //
32                // { body }
33                //
34                // into:
35                //
36                // let __postcond = if contract_checks {
37                //     CONTRACT_DECLARATIONS;
38                //     contract_check_requires(PRECOND);
39                //     Some(|ret_val| POSTCOND)
40                // } else {
41                //     None
42                // };
43                // {
44                //     let ret = { body };
45                //
46                //     if contract_checks {
47                //         contract_check_ensures(__postcond, ret)
48                //     } else {
49                //         ret
50                //     }
51                // }
52
53                let precond = self.lower_precond(req);
54                let postcond_checker = self.lower_postcond_checker(ens);
55
56                let contract_check = self.lower_contract_check_with_postcond(
57                    contract_decls,
58                    Some(precond),
59                    postcond_checker,
60                );
61
62                let wrapped_body =
63                    self.wrap_body_with_contract_check(body, contract_check, postcond_checker.span);
64                self.expr_block(wrapped_body)
65            }
66            (None, Some(ens)) => {
67                // Lower the fn contract, which turns:
68                //
69                // { body }
70                //
71                // into:
72                //
73                // let __postcond = if contract_checks {
74                //     Some(|ret_val| POSTCOND)
75                // } else {
76                //     None
77                // };
78                // {
79                //     let ret = { body };
80                //
81                //     if contract_checks {
82                //         CONTRACT_DECLARATIONS;
83                //         contract_check_ensures(__postcond, ret)
84                //     } else {
85                //         ret
86                //     }
87                // }
88                let postcond_checker = self.lower_postcond_checker(ens);
89                let contract_check =
90                    self.lower_contract_check_with_postcond(contract_decls, None, postcond_checker);
91
92                let wrapped_body =
93                    self.wrap_body_with_contract_check(body, contract_check, postcond_checker.span);
94                self.expr_block(wrapped_body)
95            }
96            (Some(req), None) => {
97                // Lower the fn contract, which turns:
98                //
99                // { body }
100                //
101                // into:
102                //
103                // {
104                //      if contracts_checks {
105                //          CONTRACT_DECLARATIONS;
106                //          contract_requires(PRECOND);
107                //      }
108                //      body
109                // }
110                let precond = self.lower_precond(req);
111                let precond_check = self.lower_contract_check_just_precond(contract_decls, precond);
112
113                let body = self.arena.alloc(body(self));
114
115                // Flatten the body into precond check, then body.
116                let wrapped_body = self.block_all(
117                    body.span,
118                    self.arena.alloc_from_iter([precond_check].into_iter()),
119                    Some(body),
120                );
121                self.expr_block(wrapped_body)
122            }
123            (None, None) => body(self),
124        }
125    }
126
127    fn lower_decls(&mut self, contract: &rustc_ast::FnContract) -> &'hir [rustc_hir::Stmt<'hir>] {
128        let (decls, decls_tail) = self.lower_stmts(&contract.declarations);
129
130        if let Some(e) = decls_tail {
131            // include the tail expression in the declaration statements
132            let tail = self.stmt_expr(e.span, *e);
133            self.arena.alloc_from_iter(decls.into_iter().map(|d| *d).chain([tail].into_iter()))
134        } else {
135            decls
136        }
137    }
138
139    /// Lower the precondition check intrinsic.
140    fn lower_precond(&mut self, req: &Box<rustc_ast::Expr>) -> rustc_hir::Stmt<'hir> {
141        let lowered_req = self.lower_expr_mut(&req);
142        let req_span = self.mark_span_with_reason(
143            rustc_span::DesugaringKind::Contract,
144            lowered_req.span,
145            Some(Arc::clone(&self.allow_contracts)),
146        );
147        let precond = self.expr_call_lang_item_fn_mut(
148            req_span,
149            rustc_hir::LangItem::ContractCheckRequires,
150            &*arena_vec![self; lowered_req],
151        );
152        self.stmt_expr(req.span, precond)
153    }
154
155    fn lower_postcond_checker(
156        &mut self,
157        ens: &Box<rustc_ast::Expr>,
158    ) -> &'hir rustc_hir::Expr<'hir> {
159        let ens_span = self.lower_span(ens.span);
160        let ens_span = self.mark_span_with_reason(
161            rustc_span::DesugaringKind::Contract,
162            ens_span,
163            Some(Arc::clone(&self.allow_contracts)),
164        );
165        let lowered_ens = self.lower_expr_mut(&ens);
166        self.expr_call_lang_item_fn(
167            ens_span,
168            rustc_hir::LangItem::ContractBuildCheckEnsures,
169            &*arena_vec![self; lowered_ens],
170        )
171    }
172
173    fn lower_contract_check_just_precond(
174        &mut self,
175        contract_decls: &'hir [rustc_hir::Stmt<'hir>],
176        precond: rustc_hir::Stmt<'hir>,
177    ) -> rustc_hir::Stmt<'hir> {
178        let stmts = self
179            .arena
180            .alloc_from_iter(contract_decls.into_iter().map(|d| *d).chain([precond].into_iter()));
181
182        let then_block_stmts = self.block_all(precond.span, stmts, None);
183        let then_block = self.arena.alloc(self.expr_block(&then_block_stmts));
184
185        let precond_check = rustc_hir::ExprKind::If(
186            self.arena.alloc(self.expr_bool_literal(precond.span, self.tcx.sess.contract_checks())),
187            then_block,
188            None,
189        );
190
191        let precond_check = self.expr(precond.span, precond_check);
192        self.stmt_expr(precond.span, precond_check)
193    }
194
195    fn lower_contract_check_with_postcond(
196        &mut self,
197        contract_decls: &'hir [rustc_hir::Stmt<'hir>],
198        precond: Option<rustc_hir::Stmt<'hir>>,
199        postcond_checker: &'hir rustc_hir::Expr<'hir>,
200    ) -> &'hir rustc_hir::Expr<'hir> {
201        let stmts = self
202            .arena
203            .alloc_from_iter(contract_decls.into_iter().map(|d| *d).chain(precond.into_iter()));
204        let span = match precond {
205            Some(precond) => precond.span,
206            None => postcond_checker.span,
207        };
208
209        let postcond_checker = self.arena.alloc(self.expr_enum_variant_lang_item(
210            postcond_checker.span,
211            rustc_hir::lang_items::LangItem::OptionSome,
212            &*arena_vec![self; *postcond_checker],
213        ));
214        let then_block_stmts = self.block_all(span, stmts, Some(postcond_checker));
215        let then_block = self.arena.alloc(self.expr_block(&then_block_stmts));
216
217        let none_expr = self.arena.alloc(self.expr_enum_variant_lang_item(
218            postcond_checker.span,
219            rustc_hir::lang_items::LangItem::OptionNone,
220            Default::default(),
221        ));
222        let else_block = self.block_expr(none_expr);
223        let else_block = self.arena.alloc(self.expr_block(else_block));
224
225        let contract_check = rustc_hir::ExprKind::If(
226            self.arena.alloc(self.expr_bool_literal(span, self.tcx.sess.contract_checks())),
227            then_block,
228            Some(else_block),
229        );
230        self.arena.alloc(self.expr(span, contract_check))
231    }
232
233    fn wrap_body_with_contract_check(
234        &mut self,
235        body: impl FnOnce(&mut Self) -> rustc_hir::Expr<'hir>,
236        contract_check: &'hir rustc_hir::Expr<'hir>,
237        postcond_span: rustc_span::Span,
238    ) -> &'hir rustc_hir::Block<'hir> {
239        let check_ident: rustc_span::Ident =
240            rustc_span::Ident::from_str_and_span("__ensures_checker", postcond_span);
241        let (check_hir_id, postcond_decl) = {
242            // Set up the postcondition `let` statement.
243            let (checker_pat, check_hir_id) = self.pat_ident_binding_mode_mut(
244                postcond_span,
245                check_ident,
246                rustc_hir::BindingMode::NONE,
247            );
248            (
249                check_hir_id,
250                self.stmt_let_pat(
251                    None,
252                    postcond_span,
253                    Some(contract_check),
254                    self.arena.alloc(checker_pat),
255                    rustc_hir::LocalSource::Contract,
256                ),
257            )
258        };
259
260        // Install contract_ensures so we will intercept `return` statements,
261        // then lower the body.
262        self.contract_ensures = Some((postcond_span, check_ident, check_hir_id));
263        let body = self.arena.alloc(body(self));
264
265        // Finally, inject an ensures check on the implicit return of the body.
266        let body = self.inject_ensures_check(body, postcond_span, check_ident, check_hir_id);
267
268        // Flatten the body into precond, then postcond, then wrapped body.
269        let wrapped_body = self.block_all(
270            body.span,
271            self.arena.alloc_from_iter([postcond_decl].into_iter()),
272            Some(body),
273        );
274        wrapped_body
275    }
276
277    /// Create an `ExprKind::Ret` that is optionally wrapped by a call to check
278    /// a contract ensures clause, if it exists.
279    pub(super) fn checked_return(
280        &mut self,
281        opt_expr: Option<&'hir rustc_hir::Expr<'hir>>,
282    ) -> rustc_hir::ExprKind<'hir> {
283        let checked_ret =
284            if let Some((check_span, check_ident, check_hir_id)) = self.contract_ensures {
285                let expr = opt_expr.unwrap_or_else(|| self.expr_unit(check_span));
286                Some(self.inject_ensures_check(expr, check_span, check_ident, check_hir_id))
287            } else {
288                opt_expr
289            };
290        rustc_hir::ExprKind::Ret(checked_ret)
291    }
292
293    /// Wraps an expression with a call to the ensures check before it gets returned.
294    pub(super) fn inject_ensures_check(
295        &mut self,
296        expr: &'hir rustc_hir::Expr<'hir>,
297        span: rustc_span::Span,
298        cond_ident: rustc_span::Ident,
299        cond_hir_id: rustc_hir::HirId,
300    ) -> &'hir rustc_hir::Expr<'hir> {
301        // {
302        //     let ret = { body };
303        //
304        //     if contract_checks {
305        //         contract_check_ensures(__postcond, ret)
306        //     } else {
307        //         ret
308        //     }
309        // }
310        let ret_ident: rustc_span::Ident = rustc_span::Ident::from_str_and_span("__ret", span);
311
312        // Set up the return `let` statement.
313        let (ret_pat, ret_hir_id) =
314            self.pat_ident_binding_mode_mut(span, ret_ident, rustc_hir::BindingMode::NONE);
315
316        let ret_stmt = self.stmt_let_pat(
317            None,
318            span,
319            Some(expr),
320            self.arena.alloc(ret_pat),
321            rustc_hir::LocalSource::Contract,
322        );
323
324        let ret = self.expr_ident(span, ret_ident, ret_hir_id);
325
326        let cond_fn = self.expr_ident(span, cond_ident, cond_hir_id);
327        let contract_check = self.expr_call_lang_item_fn_mut(
328            span,
329            rustc_hir::LangItem::ContractCheckEnsures,
330            arena_vec![self; *cond_fn, *ret],
331        );
332        let contract_check = self.arena.alloc(contract_check);
333        let call_expr = self.block_expr_block(contract_check);
334
335        // same ident can't be used in 2 places, so we create a new one for the
336        // else branch
337        let ret = self.expr_ident(span, ret_ident, ret_hir_id);
338        let ret_block = self.block_expr_block(ret);
339
340        let contracts_enabled: rustc_hir::Expr<'_> =
341            self.expr_bool_literal(span, self.tcx.sess.contract_checks());
342        let contract_check = self.arena.alloc(self.expr(
343            span,
344            rustc_hir::ExprKind::If(
345                self.arena.alloc(contracts_enabled),
346                call_expr,
347                Some(ret_block),
348            ),
349        ));
350
351        let attrs: rustc_ast::AttrVec = thin_vec![self.unreachable_code_attr(span)];
352        self.lower_attrs(contract_check.hir_id, &attrs, span, rustc_hir::Target::Expression);
353
354        let ret_block = self.block_all(span, arena_vec![self; ret_stmt], Some(contract_check));
355        self.arena.alloc(self.expr_block(self.arena.alloc(ret_block)))
356    }
357}