rustc_ast_lowering/
contract.rs

1use std::sync::Arc;
2
3use thin_vec::thin_vec;
4
5use crate::LoweringContext;
6
7impl<'a, 'hir> LoweringContext<'a, 'hir> {
8    /// Lowered contracts are guarded with the `contract_checks` compiler flag,
9    /// i.e. the flag turns into a boolean guard in the lowered HIR. The reason
10    /// for not eliminating the contract code entirely when the `contract_checks`
11    /// flag is disabled is so that contracts can be type checked, even when
12    /// they are disabled, which avoids them becoming stale (i.e. out of sync
13    /// with the codebase) over time.
14    ///
15    /// The optimiser should be able to eliminate all contract code guarded
16    /// by `if false`, leaving the original body intact when runtime contract
17    /// checks are disabled.
18    pub(super) fn lower_contract(
19        &mut self,
20        body: impl FnOnce(&mut Self) -> rustc_hir::Expr<'hir>,
21        contract: &rustc_ast::FnContract,
22    ) -> rustc_hir::Expr<'hir> {
23        // The order in which things are lowered is important! I.e to
24        // refer to variables in contract_decls from postcond/precond,
25        // we must lower it first!
26        let contract_decls = self.lower_stmts(&contract.declarations).0;
27
28        match (&contract.requires, &contract.ensures) {
29            (Some(req), Some(ens)) => {
30                // Lower the fn contract, which turns:
31                //
32                // { body }
33                //
34                // into:
35                //
36                // let __postcond = if contract_checks {
37                //     CONTRACT_DECLARATIONS;
38                //     contract_check_requires(PRECOND);
39                //     Some(|ret_val| POSTCOND)
40                // } else {
41                //     None
42                // };
43                // {
44                //     let ret = { body };
45                //
46                //     if contract_checks {
47                //         contract_check_ensures(__postcond, ret)
48                //     } else {
49                //         ret
50                //     }
51                // }
52
53                let precond = self.lower_precond(req);
54                let postcond_checker = self.lower_postcond_checker(ens);
55
56                let contract_check = self.lower_contract_check_with_postcond(
57                    contract_decls,
58                    Some(precond),
59                    postcond_checker,
60                );
61
62                let wrapped_body =
63                    self.wrap_body_with_contract_check(body, contract_check, postcond_checker.span);
64                self.expr_block(wrapped_body)
65            }
66            (None, Some(ens)) => {
67                // Lower the fn contract, which turns:
68                //
69                // { body }
70                //
71                // into:
72                //
73                // let __postcond = if contract_checks {
74                //     Some(|ret_val| POSTCOND)
75                // } else {
76                //     None
77                // };
78                // {
79                //     let ret = { body };
80                //
81                //     if contract_checks {
82                //         CONTRACT_DECLARATIONS;
83                //         contract_check_ensures(__postcond, ret)
84                //     } else {
85                //         ret
86                //     }
87                // }
88                let postcond_checker = self.lower_postcond_checker(ens);
89                let contract_check =
90                    self.lower_contract_check_with_postcond(contract_decls, None, postcond_checker);
91
92                let wrapped_body =
93                    self.wrap_body_with_contract_check(body, contract_check, postcond_checker.span);
94                self.expr_block(wrapped_body)
95            }
96            (Some(req), None) => {
97                // Lower the fn contract, which turns:
98                //
99                // { body }
100                //
101                // into:
102                //
103                // {
104                //      if contracts_checks {
105                //          CONTRACT_DECLARATIONS;
106                //          contract_requires(PRECOND);
107                //      }
108                //      body
109                // }
110                let precond = self.lower_precond(req);
111                let precond_check = self.lower_contract_check_just_precond(contract_decls, precond);
112
113                let body = self.arena.alloc(body(self));
114
115                // Flatten the body into precond check, then body.
116                let wrapped_body = self.block_all(
117                    body.span,
118                    self.arena.alloc_from_iter([precond_check].into_iter()),
119                    Some(body),
120                );
121                self.expr_block(wrapped_body)
122            }
123            (None, None) => body(self),
124        }
125    }
126
127    /// Lower the precondition check intrinsic.
128    fn lower_precond(&mut self, req: &Box<rustc_ast::Expr>) -> rustc_hir::Stmt<'hir> {
129        let lowered_req = self.lower_expr_mut(&req);
130        let req_span = self.mark_span_with_reason(
131            rustc_span::DesugaringKind::Contract,
132            lowered_req.span,
133            Some(Arc::clone(&self.allow_contracts)),
134        );
135        let precond = self.expr_call_lang_item_fn_mut(
136            req_span,
137            rustc_hir::LangItem::ContractCheckRequires,
138            &*arena_vec![self; lowered_req],
139        );
140        self.stmt_expr(req.span, precond)
141    }
142
143    fn lower_postcond_checker(
144        &mut self,
145        ens: &Box<rustc_ast::Expr>,
146    ) -> &'hir rustc_hir::Expr<'hir> {
147        let ens_span = self.lower_span(ens.span);
148        let ens_span = self.mark_span_with_reason(
149            rustc_span::DesugaringKind::Contract,
150            ens_span,
151            Some(Arc::clone(&self.allow_contracts)),
152        );
153        let lowered_ens = self.lower_expr_mut(&ens);
154        self.expr_call_lang_item_fn(
155            ens_span,
156            rustc_hir::LangItem::ContractBuildCheckEnsures,
157            &*arena_vec![self; lowered_ens],
158        )
159    }
160
161    fn lower_contract_check_just_precond(
162        &mut self,
163        contract_decls: &'hir [rustc_hir::Stmt<'hir>],
164        precond: rustc_hir::Stmt<'hir>,
165    ) -> rustc_hir::Stmt<'hir> {
166        let stmts = self
167            .arena
168            .alloc_from_iter(contract_decls.into_iter().map(|d| *d).chain([precond].into_iter()));
169
170        let then_block_stmts = self.block_all(precond.span, stmts, None);
171        let then_block = self.arena.alloc(self.expr_block(&then_block_stmts));
172
173        let precond_check = rustc_hir::ExprKind::If(
174            self.arena.alloc(self.expr_bool_literal(precond.span, self.tcx.sess.contract_checks())),
175            then_block,
176            None,
177        );
178
179        let precond_check = self.expr(precond.span, precond_check);
180        self.stmt_expr(precond.span, precond_check)
181    }
182
183    fn lower_contract_check_with_postcond(
184        &mut self,
185        contract_decls: &'hir [rustc_hir::Stmt<'hir>],
186        precond: Option<rustc_hir::Stmt<'hir>>,
187        postcond_checker: &'hir rustc_hir::Expr<'hir>,
188    ) -> &'hir rustc_hir::Expr<'hir> {
189        let stmts = self
190            .arena
191            .alloc_from_iter(contract_decls.into_iter().map(|d| *d).chain(precond.into_iter()));
192        let span = match precond {
193            Some(precond) => precond.span,
194            None => postcond_checker.span,
195        };
196
197        let postcond_checker = self.arena.alloc(self.expr_enum_variant_lang_item(
198            postcond_checker.span,
199            rustc_hir::lang_items::LangItem::OptionSome,
200            &*arena_vec![self; *postcond_checker],
201        ));
202        let then_block_stmts = self.block_all(span, stmts, Some(postcond_checker));
203        let then_block = self.arena.alloc(self.expr_block(&then_block_stmts));
204
205        let none_expr = self.arena.alloc(self.expr_enum_variant_lang_item(
206            postcond_checker.span,
207            rustc_hir::lang_items::LangItem::OptionNone,
208            Default::default(),
209        ));
210        let else_block = self.block_expr(none_expr);
211        let else_block = self.arena.alloc(self.expr_block(else_block));
212
213        let contract_check = rustc_hir::ExprKind::If(
214            self.arena.alloc(self.expr_bool_literal(span, self.tcx.sess.contract_checks())),
215            then_block,
216            Some(else_block),
217        );
218        self.arena.alloc(self.expr(span, contract_check))
219    }
220
221    fn wrap_body_with_contract_check(
222        &mut self,
223        body: impl FnOnce(&mut Self) -> rustc_hir::Expr<'hir>,
224        contract_check: &'hir rustc_hir::Expr<'hir>,
225        postcond_span: rustc_span::Span,
226    ) -> &'hir rustc_hir::Block<'hir> {
227        let check_ident: rustc_span::Ident =
228            rustc_span::Ident::from_str_and_span("__ensures_checker", postcond_span);
229        let (check_hir_id, postcond_decl) = {
230            // Set up the postcondition `let` statement.
231            let (checker_pat, check_hir_id) = self.pat_ident_binding_mode_mut(
232                postcond_span,
233                check_ident,
234                rustc_hir::BindingMode::NONE,
235            );
236            (
237                check_hir_id,
238                self.stmt_let_pat(
239                    None,
240                    postcond_span,
241                    Some(contract_check),
242                    self.arena.alloc(checker_pat),
243                    rustc_hir::LocalSource::Contract,
244                ),
245            )
246        };
247
248        // Install contract_ensures so we will intercept `return` statements,
249        // then lower the body.
250        self.contract_ensures = Some((postcond_span, check_ident, check_hir_id));
251        let body = self.arena.alloc(body(self));
252
253        // Finally, inject an ensures check on the implicit return of the body.
254        let body = self.inject_ensures_check(body, postcond_span, check_ident, check_hir_id);
255
256        // Flatten the body into precond, then postcond, then wrapped body.
257        let wrapped_body = self.block_all(
258            body.span,
259            self.arena.alloc_from_iter([postcond_decl].into_iter()),
260            Some(body),
261        );
262        wrapped_body
263    }
264
265    /// Create an `ExprKind::Ret` that is optionally wrapped by a call to check
266    /// a contract ensures clause, if it exists.
267    pub(super) fn checked_return(
268        &mut self,
269        opt_expr: Option<&'hir rustc_hir::Expr<'hir>>,
270    ) -> rustc_hir::ExprKind<'hir> {
271        let checked_ret =
272            if let Some((check_span, check_ident, check_hir_id)) = self.contract_ensures {
273                let expr = opt_expr.unwrap_or_else(|| self.expr_unit(check_span));
274                Some(self.inject_ensures_check(expr, check_span, check_ident, check_hir_id))
275            } else {
276                opt_expr
277            };
278        rustc_hir::ExprKind::Ret(checked_ret)
279    }
280
281    /// Wraps an expression with a call to the ensures check before it gets returned.
282    pub(super) fn inject_ensures_check(
283        &mut self,
284        expr: &'hir rustc_hir::Expr<'hir>,
285        span: rustc_span::Span,
286        cond_ident: rustc_span::Ident,
287        cond_hir_id: rustc_hir::HirId,
288    ) -> &'hir rustc_hir::Expr<'hir> {
289        // {
290        //     let ret = { body };
291        //
292        //     if contract_checks {
293        //         contract_check_ensures(__postcond, ret)
294        //     } else {
295        //         ret
296        //     }
297        // }
298        let ret_ident: rustc_span::Ident = rustc_span::Ident::from_str_and_span("__ret", span);
299
300        // Set up the return `let` statement.
301        let (ret_pat, ret_hir_id) =
302            self.pat_ident_binding_mode_mut(span, ret_ident, rustc_hir::BindingMode::NONE);
303
304        let ret_stmt = self.stmt_let_pat(
305            None,
306            span,
307            Some(expr),
308            self.arena.alloc(ret_pat),
309            rustc_hir::LocalSource::Contract,
310        );
311
312        let ret = self.expr_ident(span, ret_ident, ret_hir_id);
313
314        let cond_fn = self.expr_ident(span, cond_ident, cond_hir_id);
315        let contract_check = self.expr_call_lang_item_fn_mut(
316            span,
317            rustc_hir::LangItem::ContractCheckEnsures,
318            arena_vec![self; *cond_fn, *ret],
319        );
320        let contract_check = self.arena.alloc(contract_check);
321        let call_expr = self.block_expr_block(contract_check);
322
323        // same ident can't be used in 2 places, so we create a new one for the
324        // else branch
325        let ret = self.expr_ident(span, ret_ident, ret_hir_id);
326        let ret_block = self.block_expr_block(ret);
327
328        let contracts_enabled: rustc_hir::Expr<'_> =
329            self.expr_bool_literal(span, self.tcx.sess.contract_checks());
330        let contract_check = self.arena.alloc(self.expr(
331            span,
332            rustc_hir::ExprKind::If(
333                self.arena.alloc(contracts_enabled),
334                call_expr,
335                Some(ret_block),
336            ),
337        ));
338
339        let attrs: rustc_ast::AttrVec = thin_vec![self.unreachable_code_attr(span)];
340        self.lower_attrs(contract_check.hir_id, &attrs, span, rustc_hir::Target::Expression);
341
342        let ret_block = self.block_all(span, arena_vec![self; ret_stmt], Some(contract_check));
343        self.arena.alloc(self.expr_block(self.arena.alloc(ret_block)))
344    }
345}