// rustc_ast_lowering/contract.rs

1use thin_vec::thin_vec;
2
3use crate::LoweringContext;
4
5impl<'a, 'hir> LoweringContext<'a, 'hir> {
6    /// Lowered contracts are guarded with the `contract_checks` compiler flag,
7    /// i.e. the flag turns into a boolean guard in the lowered HIR. The reason
8    /// for not eliminating the contract code entirely when the `contract_checks`
9    /// flag is disabled is so that contracts can be type checked, even when
10    /// they are disabled, which avoids them becoming stale (i.e. out of sync
11    /// with the codebase) over time.
12    ///
13    /// The optimiser should be able to eliminate all contract code guarded
14    /// by `if false`, leaving the original body intact when runtime contract
15    /// checks are disabled.
16    pub(super) fn lower_contract(
17        &mut self,
18        body: impl FnOnce(&mut Self) -> rustc_hir::Expr<'hir>,
19        contract: &rustc_ast::FnContract,
20    ) -> rustc_hir::Expr<'hir> {
21        match (&contract.requires, &contract.ensures) {
22            (Some(req), Some(ens)) => {
23                // Lower the fn contract, which turns:
24                //
25                // { body }
26                //
27                // into:
28                //
29                // let __postcond = if contract_checks {
30                //     contract_check_requires(PRECOND);
31                //     Some(|ret_val| POSTCOND)
32                // } else {
33                //     None
34                // };
35                // {
36                //     let ret = { body };
37                //
38                //     if contract_checks {
39                //         contract_check_ensures(__postcond, ret)
40                //     } else {
41                //         ret
42                //     }
43                // }
44
45                let precond = self.lower_precond(req);
46                let postcond_checker = self.lower_postcond_checker(ens);
47
48                let contract_check =
49                    self.lower_contract_check_with_postcond(Some(precond), postcond_checker);
50
51                let wrapped_body =
52                    self.wrap_body_with_contract_check(body, contract_check, postcond_checker.span);
53                self.expr_block(wrapped_body)
54            }
55            (None, Some(ens)) => {
56                // Lower the fn contract, which turns:
57                //
58                // { body }
59                //
60                // into:
61                //
62                // let __postcond = if contract_checks {
63                //     Some(|ret_val| POSTCOND)
64                // } else {
65                //     None
66                // };
67                // {
68                //     let ret = { body };
69                //
70                //     if contract_checks {
71                //         contract_check_ensures(__postcond, ret)
72                //     } else {
73                //         ret
74                //     }
75                // }
76
77                let postcond_checker = self.lower_postcond_checker(ens);
78                let contract_check =
79                    self.lower_contract_check_with_postcond(None, postcond_checker);
80
81                let wrapped_body =
82                    self.wrap_body_with_contract_check(body, contract_check, postcond_checker.span);
83                self.expr_block(wrapped_body)
84            }
85            (Some(req), None) => {
86                // Lower the fn contract, which turns:
87                //
88                // { body }
89                //
90                // into:
91                //
92                // {
93                //      if contracts_checks {
94                //          contract_requires(PRECOND);
95                //      }
96                //      body
97                // }
98                let precond = self.lower_precond(req);
99                let precond_check = self.lower_contract_check_just_precond(precond);
100
101                let body = self.arena.alloc(body(self));
102
103                // Flatten the body into precond check, then body.
104                let wrapped_body = self.block_all(
105                    body.span,
106                    self.arena.alloc_from_iter([precond_check].into_iter()),
107                    Some(body),
108                );
109                self.expr_block(wrapped_body)
110            }
111            (None, None) => body(self),
112        }
113    }
114
115    /// Lower the precondition check intrinsic.
116    fn lower_precond(&mut self, req: &Box<rustc_ast::Expr>) -> rustc_hir::Stmt<'hir> {
117        let lowered_req = self.lower_expr_mut(&req);
118        let req_span = self.mark_span_with_reason(
119            rustc_span::DesugaringKind::Contract,
120            lowered_req.span,
121            None,
122        );
123        let precond = self.expr_call_lang_item_fn_mut(
124            req_span,
125            rustc_hir::LangItem::ContractCheckRequires,
126            &*arena_vec![self; lowered_req],
127        );
128        self.stmt_expr(req.span, precond)
129    }
130
131    fn lower_postcond_checker(
132        &mut self,
133        ens: &Box<rustc_ast::Expr>,
134    ) -> &'hir rustc_hir::Expr<'hir> {
135        let ens_span = self.lower_span(ens.span);
136        let ens_span =
137            self.mark_span_with_reason(rustc_span::DesugaringKind::Contract, ens_span, None);
138        let lowered_ens = self.lower_expr_mut(&ens);
139        self.expr_call_lang_item_fn(
140            ens_span,
141            rustc_hir::LangItem::ContractBuildCheckEnsures,
142            &*arena_vec![self; lowered_ens],
143        )
144    }
145
146    fn lower_contract_check_just_precond(
147        &mut self,
148        precond: rustc_hir::Stmt<'hir>,
149    ) -> rustc_hir::Stmt<'hir> {
150        let stmts = self.arena.alloc_from_iter([precond].into_iter());
151
152        let then_block_stmts = self.block_all(precond.span, stmts, None);
153        let then_block = self.arena.alloc(self.expr_block(&then_block_stmts));
154
155        let precond_check = rustc_hir::ExprKind::If(
156            self.arena.alloc(self.expr_bool_literal(precond.span, self.tcx.sess.contract_checks())),
157            then_block,
158            None,
159        );
160
161        let precond_check = self.expr(precond.span, precond_check);
162        self.stmt_expr(precond.span, precond_check)
163    }
164
165    fn lower_contract_check_with_postcond(
166        &mut self,
167        precond: Option<rustc_hir::Stmt<'hir>>,
168        postcond_checker: &'hir rustc_hir::Expr<'hir>,
169    ) -> &'hir rustc_hir::Expr<'hir> {
170        let stmts = self.arena.alloc_from_iter(precond.into_iter());
171        let span = match precond {
172            Some(precond) => precond.span,
173            None => postcond_checker.span,
174        };
175
176        let postcond_checker = self.arena.alloc(self.expr_enum_variant_lang_item(
177            postcond_checker.span,
178            rustc_hir::lang_items::LangItem::OptionSome,
179            &*arena_vec![self; *postcond_checker],
180        ));
181        let then_block_stmts = self.block_all(span, stmts, Some(postcond_checker));
182        let then_block = self.arena.alloc(self.expr_block(&then_block_stmts));
183
184        let none_expr = self.arena.alloc(self.expr_enum_variant_lang_item(
185            postcond_checker.span,
186            rustc_hir::lang_items::LangItem::OptionNone,
187            Default::default(),
188        ));
189        let else_block = self.block_expr(none_expr);
190        let else_block = self.arena.alloc(self.expr_block(else_block));
191
192        let contract_check = rustc_hir::ExprKind::If(
193            self.arena.alloc(self.expr_bool_literal(span, self.tcx.sess.contract_checks())),
194            then_block,
195            Some(else_block),
196        );
197        self.arena.alloc(self.expr(span, contract_check))
198    }
199
200    fn wrap_body_with_contract_check(
201        &mut self,
202        body: impl FnOnce(&mut Self) -> rustc_hir::Expr<'hir>,
203        contract_check: &'hir rustc_hir::Expr<'hir>,
204        postcond_span: rustc_span::Span,
205    ) -> &'hir rustc_hir::Block<'hir> {
206        let check_ident: rustc_span::Ident =
207            rustc_span::Ident::from_str_and_span("__ensures_checker", postcond_span);
208        let (check_hir_id, postcond_decl) = {
209            // Set up the postcondition `let` statement.
210            let (checker_pat, check_hir_id) = self.pat_ident_binding_mode_mut(
211                postcond_span,
212                check_ident,
213                rustc_hir::BindingMode::NONE,
214            );
215            (
216                check_hir_id,
217                self.stmt_let_pat(
218                    None,
219                    postcond_span,
220                    Some(contract_check),
221                    self.arena.alloc(checker_pat),
222                    rustc_hir::LocalSource::Contract,
223                ),
224            )
225        };
226
227        // Install contract_ensures so we will intercept `return` statements,
228        // then lower the body.
229        self.contract_ensures = Some((postcond_span, check_ident, check_hir_id));
230        let body = self.arena.alloc(body(self));
231
232        // Finally, inject an ensures check on the implicit return of the body.
233        let body = self.inject_ensures_check(body, postcond_span, check_ident, check_hir_id);
234
235        // Flatten the body into precond, then postcond, then wrapped body.
236        let wrapped_body = self.block_all(
237            body.span,
238            self.arena.alloc_from_iter([postcond_decl].into_iter()),
239            Some(body),
240        );
241        wrapped_body
242    }
243
244    /// Create an `ExprKind::Ret` that is optionally wrapped by a call to check
245    /// a contract ensures clause, if it exists.
246    pub(super) fn checked_return(
247        &mut self,
248        opt_expr: Option<&'hir rustc_hir::Expr<'hir>>,
249    ) -> rustc_hir::ExprKind<'hir> {
250        let checked_ret =
251            if let Some((check_span, check_ident, check_hir_id)) = self.contract_ensures {
252                let expr = opt_expr.unwrap_or_else(|| self.expr_unit(check_span));
253                Some(self.inject_ensures_check(expr, check_span, check_ident, check_hir_id))
254            } else {
255                opt_expr
256            };
257        rustc_hir::ExprKind::Ret(checked_ret)
258    }
259
260    /// Wraps an expression with a call to the ensures check before it gets returned.
261    pub(super) fn inject_ensures_check(
262        &mut self,
263        expr: &'hir rustc_hir::Expr<'hir>,
264        span: rustc_span::Span,
265        cond_ident: rustc_span::Ident,
266        cond_hir_id: rustc_hir::HirId,
267    ) -> &'hir rustc_hir::Expr<'hir> {
268        // {
269        //     let ret = { body };
270        //
271        //     if contract_checks {
272        //         contract_check_ensures(__postcond, ret)
273        //     } else {
274        //         ret
275        //     }
276        // }
277        let ret_ident: rustc_span::Ident = rustc_span::Ident::from_str_and_span("__ret", span);
278
279        // Set up the return `let` statement.
280        let (ret_pat, ret_hir_id) =
281            self.pat_ident_binding_mode_mut(span, ret_ident, rustc_hir::BindingMode::NONE);
282
283        let ret_stmt = self.stmt_let_pat(
284            None,
285            span,
286            Some(expr),
287            self.arena.alloc(ret_pat),
288            rustc_hir::LocalSource::Contract,
289        );
290
291        let ret = self.expr_ident(span, ret_ident, ret_hir_id);
292
293        let cond_fn = self.expr_ident(span, cond_ident, cond_hir_id);
294        let contract_check = self.expr_call_lang_item_fn_mut(
295            span,
296            rustc_hir::LangItem::ContractCheckEnsures,
297            arena_vec![self; *cond_fn, *ret],
298        );
299        let contract_check = self.arena.alloc(contract_check);
300        let call_expr = self.block_expr_block(contract_check);
301
302        // same ident can't be used in 2 places, so we create a new one for the
303        // else branch
304        let ret = self.expr_ident(span, ret_ident, ret_hir_id);
305        let ret_block = self.block_expr_block(ret);
306
307        let contracts_enabled: rustc_hir::Expr<'_> =
308            self.expr_bool_literal(span, self.tcx.sess.contract_checks());
309        let contract_check = self.arena.alloc(self.expr(
310            span,
311            rustc_hir::ExprKind::If(
312                self.arena.alloc(contracts_enabled),
313                call_expr,
314                Some(ret_block),
315            ),
316        ));
317
318        let attrs: rustc_ast::AttrVec = thin_vec![self.unreachable_code_attr(span)];
319        self.lower_attrs(contract_check.hir_id, &attrs, span, rustc_hir::Target::Expression);
320
321        let ret_block = self.block_all(span, arena_vec![self; ret_stmt], Some(contract_check));
322        self.arena.alloc(self.expr_block(self.arena.alloc(ret_block)))
323    }
324}