rustc_builtin_macros/assert/
context.rs

1use rustc_ast::ptr::P;
2use rustc_ast::token::{self, Delimiter, IdentIsRaw};
3use rustc_ast::tokenstream::{DelimSpan, TokenStream, TokenTree};
4use rustc_ast::{
5    BinOpKind, BorrowKind, DUMMY_NODE_ID, DelimArgs, Expr, ExprKind, ItemKind, MacCall, MethodCall,
6    Mutability, Path, PathSegment, Stmt, StructRest, UnOp, UseTree, UseTreeKind,
7};
8use rustc_ast_pretty::pprust;
9use rustc_data_structures::fx::FxHashSet;
10use rustc_expand::base::ExtCtxt;
11use rustc_span::{Ident, Span, Symbol, sym};
12use thin_vec::{ThinVec, thin_vec};
13
14pub(super) struct Context<'cx, 'a> {
15    // An optimization.
16    //
17    // Elements that aren't consumed (PartialEq, PartialOrd, ...) can be copied **after** the
18    // `assert!` expression fails rather than copied on-the-fly.
19    best_case_captures: Vec<Stmt>,
20    // Top-level `let captureN = Capture::new()` statements
21    capture_decls: Vec<Capture>,
22    cx: &'cx ExtCtxt<'a>,
23    // Formatting string used for debugging
24    fmt_string: String,
25    // If the current expression being visited consumes itself. Used to construct
26    // `best_case_captures`.
27    is_consumed: bool,
28    // Top-level `let __local_bindN = &expr` statements
29    local_bind_decls: Vec<Stmt>,
30    // Used to avoid capturing duplicated paths
31    //
32    // ```rust
33    // let a = 1i32;
34    // assert!(add(a, a) == 3);
35    // ```
36    paths: FxHashSet<Ident>,
37    span: Span,
38}
39
40impl<'cx, 'a> Context<'cx, 'a> {
41    pub(super) fn new(cx: &'cx ExtCtxt<'a>, span: Span) -> Self {
42        Self {
43            best_case_captures: <_>::default(),
44            capture_decls: <_>::default(),
45            cx,
46            fmt_string: <_>::default(),
47            is_consumed: true,
48            local_bind_decls: <_>::default(),
49            paths: <_>::default(),
50            span,
51        }
52    }
53
54    /// Builds the whole `assert!` expression. For example, `let elem = 1; assert!(elem == 1);` expands to:
55    ///
56    /// ```rust
57    /// let elem = 1;
58    /// {
59    ///   #[allow(unused_imports)]
60    ///   use ::core::asserting::{TryCaptureGeneric, TryCapturePrintable};
61    ///   let mut __capture0 = ::core::asserting::Capture::new();
62    ///   let __local_bind0 = &elem;
63    ///   if !(
64    ///     *{
65    ///       (&::core::asserting::Wrapper(__local_bind0)).try_capture(&mut __capture0);
66    ///       __local_bind0
67    ///     } == 1
68    ///   ) {
69    ///     panic!("Assertion failed: elem == 1\nWith captures:\n  elem = {:?}", __capture0)
70    ///   }
71    /// }
72    /// ```
73    pub(super) fn build(mut self, mut cond_expr: P<Expr>, panic_path: Path) -> P<Expr> {
74        let expr_str = pprust::expr_to_string(&cond_expr);
75        self.manage_cond_expr(&mut cond_expr);
76        let initial_imports = self.build_initial_imports();
77        let panic = self.build_panic(&expr_str, panic_path);
78        let cond_expr_with_unlikely = self.build_unlikely(cond_expr);
79
80        let Self { best_case_captures, capture_decls, cx, local_bind_decls, span, .. } = self;
81
82        let mut assert_then_stmts = ThinVec::with_capacity(2);
83        assert_then_stmts.extend(best_case_captures);
84        assert_then_stmts.push(self.cx.stmt_expr(panic));
85        let assert_then = self.cx.block(span, assert_then_stmts);
86
87        let mut stmts = ThinVec::with_capacity(4);
88        stmts.push(initial_imports);
89        stmts.extend(capture_decls.into_iter().map(|c| c.decl));
90        stmts.extend(local_bind_decls);
91        stmts.push(
92            cx.stmt_expr(cx.expr(span, ExprKind::If(cond_expr_with_unlikely, assert_then, None))),
93        );
94        cx.expr_block(cx.block(span, stmts))
95    }
96
97    /// Initial **trait** imports
98    ///
99    /// use ::core::asserting::{ ... };
100    fn build_initial_imports(&self) -> Stmt {
101        let nested_tree = |this: &Self, sym| {
102            (
103                UseTree {
104                    prefix: this.cx.path(this.span, vec![Ident::with_dummy_span(sym)]),
105                    kind: UseTreeKind::Simple(None),
106                    span: this.span,
107                },
108                DUMMY_NODE_ID,
109            )
110        };
111        self.cx.stmt_item(
112            self.span,
113            self.cx.item(
114                self.span,
115                Ident::empty(),
116                thin_vec![self.cx.attr_nested_word(sym::allow, sym::unused_imports, self.span)],
117                ItemKind::Use(UseTree {
118                    prefix: self.cx.path(self.span, self.cx.std_path(&[sym::asserting])),
119                    kind: UseTreeKind::Nested {
120                        items: thin_vec![
121                            nested_tree(self, sym::TryCaptureGeneric),
122                            nested_tree(self, sym::TryCapturePrintable),
123                        ],
124                        span: self.span,
125                    },
126                    span: self.span,
127                }),
128            ),
129        )
130    }
131
132    /// Takes the conditional expression of `assert!` and then wraps it inside `unlikely`
133    fn build_unlikely(&self, cond_expr: P<Expr>) -> P<Expr> {
134        let unlikely_path = self.cx.std_path(&[sym::intrinsics, sym::unlikely]);
135        self.cx.expr_call(
136            self.span,
137            self.cx.expr_path(self.cx.path(self.span, unlikely_path)),
138            thin_vec![self.cx.expr(self.span, ExprKind::Unary(UnOp::Not, cond_expr))],
139        )
140    }
141
142    /// The necessary custom `panic!(...)` expression.
143    ///
144    /// panic!(
145    ///     "Assertion failed: ... \n With expansion: ...",
146    ///     __capture0,
147    ///     ...
148    /// );
149    fn build_panic(&self, expr_str: &str, panic_path: Path) -> P<Expr> {
150        let escaped_expr_str = escape_to_fmt(expr_str);
151        let initial = [
152            TokenTree::token_joint(
153                token::Literal(token::Lit {
154                    kind: token::LitKind::Str,
155                    symbol: Symbol::intern(&if self.fmt_string.is_empty() {
156                        format!("Assertion failed: {escaped_expr_str}")
157                    } else {
158                        format!(
159                            "Assertion failed: {escaped_expr_str}\nWith captures:\n{}",
160                            &self.fmt_string
161                        )
162                    }),
163                    suffix: None,
164                }),
165                self.span,
166            ),
167            TokenTree::token_alone(token::Comma, self.span),
168        ];
169        let captures = self.capture_decls.iter().flat_map(|cap| {
170            [
171                TokenTree::token_joint(
172                    token::Ident(cap.ident.name, IdentIsRaw::No),
173                    cap.ident.span,
174                ),
175                TokenTree::token_alone(token::Comma, self.span),
176            ]
177        });
178        self.cx.expr(
179            self.span,
180            ExprKind::MacCall(P(MacCall {
181                path: panic_path,
182                args: P(DelimArgs {
183                    dspan: DelimSpan::from_single(self.span),
184                    delim: Delimiter::Parenthesis,
185                    tokens: initial.into_iter().chain(captures).collect::<TokenStream>(),
186                }),
187            })),
188        )
189    }
190
191    /// Recursive function called until `cond_expr` and `fmt_str` are fully modified.
192    ///
193    /// See [Self::manage_initial_capture] and [Self::manage_try_capture]
194    fn manage_cond_expr(&mut self, expr: &mut P<Expr>) {
195        match &mut expr.kind {
196            ExprKind::AddrOf(_, mutability, local_expr) => {
197                self.with_is_consumed_management(matches!(mutability, Mutability::Mut), |this| {
198                    this.manage_cond_expr(local_expr)
199                });
200            }
201            ExprKind::Array(local_exprs) => {
202                for local_expr in local_exprs {
203                    self.manage_cond_expr(local_expr);
204                }
205            }
206            ExprKind::Binary(op, lhs, rhs) => {
207                self.with_is_consumed_management(
208                    matches!(
209                        op.node,
210                        BinOpKind::Add
211                            | BinOpKind::And
212                            | BinOpKind::BitAnd
213                            | BinOpKind::BitOr
214                            | BinOpKind::BitXor
215                            | BinOpKind::Div
216                            | BinOpKind::Mul
217                            | BinOpKind::Or
218                            | BinOpKind::Rem
219                            | BinOpKind::Shl
220                            | BinOpKind::Shr
221                            | BinOpKind::Sub
222                    ),
223                    |this| {
224                        this.manage_cond_expr(lhs);
225                        this.manage_cond_expr(rhs);
226                    },
227                );
228            }
229            ExprKind::Call(_, local_exprs) => {
230                for local_expr in local_exprs {
231                    self.manage_cond_expr(local_expr);
232                }
233            }
234            ExprKind::Cast(local_expr, _) => {
235                self.manage_cond_expr(local_expr);
236            }
237            ExprKind::If(local_expr, _, _) => {
238                self.manage_cond_expr(local_expr);
239            }
240            ExprKind::Index(prefix, suffix, _) => {
241                self.manage_cond_expr(prefix);
242                self.manage_cond_expr(suffix);
243            }
244            ExprKind::Let(_, local_expr, _, _) => {
245                self.manage_cond_expr(local_expr);
246            }
247            ExprKind::Match(local_expr, ..) => {
248                self.manage_cond_expr(local_expr);
249            }
250            ExprKind::MethodCall(call) => {
251                for arg in &mut call.args {
252                    self.manage_cond_expr(arg);
253                }
254            }
255            ExprKind::Path(_, Path { segments, .. }) if let [path_segment] = &segments[..] => {
256                let path_ident = path_segment.ident;
257                self.manage_initial_capture(expr, path_ident);
258            }
259            ExprKind::Paren(local_expr) => {
260                self.manage_cond_expr(local_expr);
261            }
262            ExprKind::Range(prefix, suffix, _) => {
263                if let Some(elem) = prefix {
264                    self.manage_cond_expr(elem);
265                }
266                if let Some(elem) = suffix {
267                    self.manage_cond_expr(elem);
268                }
269            }
270            ExprKind::Repeat(local_expr, elem) => {
271                self.manage_cond_expr(local_expr);
272                self.manage_cond_expr(&mut elem.value);
273            }
274            ExprKind::Struct(elem) => {
275                for field in &mut elem.fields {
276                    self.manage_cond_expr(&mut field.expr);
277                }
278                if let StructRest::Base(local_expr) = &mut elem.rest {
279                    self.manage_cond_expr(local_expr);
280                }
281            }
282            ExprKind::Tup(local_exprs) => {
283                for local_expr in local_exprs {
284                    self.manage_cond_expr(local_expr);
285                }
286            }
287            ExprKind::Unary(un_op, local_expr) => {
288                self.with_is_consumed_management(matches!(un_op, UnOp::Neg | UnOp::Not), |this| {
289                    this.manage_cond_expr(local_expr)
290                });
291            }
292            // Expressions that are not worth or can not be captured.
293            //
294            // Full list instead of `_` to catch possible future inclusions and to
295            // sync with the `rfc-2011-nicer-assert-messages/all-expr-kinds.rs` test.
296            ExprKind::Assign(_, _, _)
297            | ExprKind::AssignOp(_, _, _)
298            | ExprKind::Gen(_, _, _, _)
299            | ExprKind::Await(_, _)
300            | ExprKind::Block(_, _)
301            | ExprKind::Break(_, _)
302            | ExprKind::Closure(_)
303            | ExprKind::ConstBlock(_)
304            | ExprKind::Continue(_)
305            | ExprKind::Dummy
306            | ExprKind::Err(_)
307            | ExprKind::Field(_, _)
308            | ExprKind::ForLoop { .. }
309            | ExprKind::FormatArgs(_)
310            | ExprKind::IncludedBytes(..)
311            | ExprKind::InlineAsm(_)
312            | ExprKind::Lit(_)
313            | ExprKind::Loop(_, _, _)
314            | ExprKind::MacCall(_)
315            | ExprKind::OffsetOf(_, _)
316            | ExprKind::Path(_, _)
317            | ExprKind::Ret(_)
318            | ExprKind::Try(_)
319            | ExprKind::TryBlock(_)
320            | ExprKind::Type(_, _)
321            | ExprKind::Underscore
322            | ExprKind::While(_, _, _)
323            | ExprKind::Yeet(_)
324            | ExprKind::Become(_)
325            | ExprKind::Yield(_)
326            | ExprKind::UnsafeBinderCast(..) => {}
327        }
328    }
329
330    /// Pushes the top-level declarations and modifies `expr` to try capturing variables.
331    ///
332    /// `fmt_str`, the formatting string used for debugging, is constructed to show possible
333    /// captured variables.
334    fn manage_initial_capture(&mut self, expr: &mut P<Expr>, path_ident: Ident) {
335        if self.paths.contains(&path_ident) {
336            return;
337        } else {
338            self.fmt_string.push_str("  ");
339            self.fmt_string.push_str(path_ident.as_str());
340            self.fmt_string.push_str(" = {:?}\n");
341            let _ = self.paths.insert(path_ident);
342        }
343        let curr_capture_idx = self.capture_decls.len();
344        let capture_string = format!("__capture{curr_capture_idx}");
345        let ident = Ident::new(Symbol::intern(&capture_string), self.span);
346        let init_std_path = self.cx.std_path(&[sym::asserting, sym::Capture, sym::new]);
347        let init = self.cx.expr_call(
348            self.span,
349            self.cx.expr_path(self.cx.path(self.span, init_std_path)),
350            ThinVec::new(),
351        );
352        let capture = Capture { decl: self.cx.stmt_let(self.span, true, ident, init), ident };
353        self.capture_decls.push(capture);
354        self.manage_try_capture(ident, curr_capture_idx, expr);
355    }
356
357    /// Tries to copy `__local_bindN` into `__captureN`.
358    ///
359    /// *{
360    ///    (&Wrapper(__local_bindN)).try_capture(&mut __captureN);
361    ///    __local_bindN
362    /// }
363    fn manage_try_capture(&mut self, capture: Ident, curr_capture_idx: usize, expr: &mut P<Expr>) {
364        let local_bind_string = format!("__local_bind{curr_capture_idx}");
365        let local_bind = Ident::new(Symbol::intern(&local_bind_string), self.span);
366        self.local_bind_decls.push(self.cx.stmt_let(
367            self.span,
368            false,
369            local_bind,
370            self.cx.expr_addr_of(self.span, expr.clone()),
371        ));
372        let wrapper = self.cx.expr_call(
373            self.span,
374            self.cx.expr_path(
375                self.cx.path(self.span, self.cx.std_path(&[sym::asserting, sym::Wrapper])),
376            ),
377            thin_vec![self.cx.expr_path(Path::from_ident(local_bind))],
378        );
379        let try_capture_call = self
380            .cx
381            .stmt_expr(expr_method_call(
382                self.cx,
383                PathSegment {
384                    args: None,
385                    id: DUMMY_NODE_ID,
386                    ident: Ident::new(sym::try_capture, self.span),
387                },
388                expr_paren(self.cx, self.span, self.cx.expr_addr_of(self.span, wrapper)),
389                thin_vec![expr_addr_of_mut(
390                    self.cx,
391                    self.span,
392                    self.cx.expr_path(Path::from_ident(capture)),
393                )],
394                self.span,
395            ))
396            .add_trailing_semicolon();
397        let local_bind_path = self.cx.expr_path(Path::from_ident(local_bind));
398        let rslt = if self.is_consumed {
399            let ret = self.cx.stmt_expr(local_bind_path);
400            self.cx.expr_block(self.cx.block(self.span, thin_vec![try_capture_call, ret]))
401        } else {
402            self.best_case_captures.push(try_capture_call);
403            local_bind_path
404        };
405        *expr = self.cx.expr_deref(self.span, rslt);
406    }
407
408    // Calls `f` with the internal `is_consumed` set to `curr_is_consumed` and then
409    // sets the internal `is_consumed` back to its original value.
410    fn with_is_consumed_management(&mut self, curr_is_consumed: bool, f: impl FnOnce(&mut Self)) {
411        let prev_is_consumed = self.is_consumed;
412        self.is_consumed = curr_is_consumed;
413        f(self);
414        self.is_consumed = prev_is_consumed;
415    }
416}
417
418/// Information about a captured element.
419#[derive(Debug)]
420struct Capture {
421    // Generated indexed `Capture` statement.
422    //
423    // `let __capture{} = Capture::new();`
424    decl: Stmt,
425    // The name of the generated indexed `Capture` variable.
426    //
427    // `__capture{}`
428    ident: Ident,
429}
430
/// Escapes `s` so that it can be embedded in a string literal used as a formatting string.
///
/// Every character goes through `char::escape_debug`, and `{` / `}` are additionally
/// doubled so that they are not read as format placeholders.
fn escape_to_fmt(s: &str) -> String {
    s.chars().fold(String::with_capacity(s.len()), |mut escaped, ch| {
        escaped.extend(ch.escape_debug());
        if matches!(ch, '{' | '}') {
            escaped.push(ch);
        }
        escaped
    })
}
443
444fn expr_addr_of_mut(cx: &ExtCtxt<'_>, sp: Span, e: P<Expr>) -> P<Expr> {
445    cx.expr(sp, ExprKind::AddrOf(BorrowKind::Ref, Mutability::Mut, e))
446}
447
448fn expr_method_call(
449    cx: &ExtCtxt<'_>,
450    seg: PathSegment,
451    receiver: P<Expr>,
452    args: ThinVec<P<Expr>>,
453    span: Span,
454) -> P<Expr> {
455    cx.expr(span, ExprKind::MethodCall(Box::new(MethodCall { seg, receiver, args, span })))
456}
457
458fn expr_paren(cx: &ExtCtxt<'_>, sp: Span, e: P<Expr>) -> P<Expr> {
459    cx.expr(sp, ExprKind::Paren(e))
460}