// rustc_builtin_macros/contracts.rs

1use rustc_ast::token;
2use rustc_ast::tokenstream::{DelimSpacing, DelimSpan, Spacing, TokenStream, TokenTree};
3use rustc_errors::ErrorGuaranteed;
4use rustc_expand::base::{AttrProcMacro, ExtCtxt};
5use rustc_span::Span;
6use rustc_span::symbol::{Ident, Symbol, kw};
7
8pub(crate) struct ExpandRequires;
9
10pub(crate) struct ExpandEnsures;
11
12impl AttrProcMacro for ExpandRequires {
13    fn expand<'cx>(
14        &self,
15        ecx: &'cx mut ExtCtxt<'_>,
16        span: Span,
17        annotation: TokenStream,
18        annotated: TokenStream,
19    ) -> Result<TokenStream, ErrorGuaranteed> {
20        expand_requires_tts(ecx, span, annotation, annotated)
21    }
22}
23
24impl AttrProcMacro for ExpandEnsures {
25    fn expand<'cx>(
26        &self,
27        ecx: &'cx mut ExtCtxt<'_>,
28        span: Span,
29        annotation: TokenStream,
30        annotated: TokenStream,
31    ) -> Result<TokenStream, ErrorGuaranteed> {
32        expand_ensures_tts(ecx, span, annotation, annotated)
33    }
34}
35
36/// Expand the function signature to include the contract clause.
37///
38/// The contracts clause will be injected before the function body and the optional where clause.
39/// For that, we search for the body / where token, and invoke the `inject` callback to generate the
40/// contract clause in the right place.
41///
42// FIXME: this kind of manual token tree munging does not have significant precedent among
43// rustc builtin macros, probably because most builtin macros use direct AST manipulation to
44// accomplish similar goals. But since our attributes need to take arbitrary expressions, and
45// our attribute infrastructure does not yet support mixing a token-tree annotation with an AST
46// annotated, we end up doing token tree manipulation.
47fn expand_contract_clause(
48    ecx: &mut ExtCtxt<'_>,
49    attr_span: Span,
50    annotated: TokenStream,
51    inject: impl FnOnce(&mut TokenStream) -> Result<(), ErrorGuaranteed>,
52) -> Result<TokenStream, ErrorGuaranteed> {
53    let mut new_tts = TokenStream::default();
54    let mut cursor = annotated.iter();
55
56    let is_kw = |tt: &TokenTree, sym: Symbol| {
57        if let TokenTree::Token(token, _) = tt { token.is_ident_named(sym) } else { false }
58    };
59
60    // Find the `fn` keyword to check if this is a function.
61    if cursor
62        .find(|tt| {
63            new_tts.push_tree((*tt).clone());
64            is_kw(tt, kw::Fn)
65        })
66        .is_none()
67    {
68        return Err(ecx
69            .sess
70            .dcx()
71            .span_err(attr_span, "contract annotations can only be used on functions"));
72    }
73
74    // Found the `fn` keyword, now find either the `where` token or the function body.
75    let next_tt = loop {
76        let Some(tt) = cursor.next() else {
77            return Err(ecx.sess.dcx().span_err(
78                attr_span,
79                "contract annotations is only supported in functions with bodies",
80            ));
81        };
82        // If `tt` is the last element. Check if it is the function body.
83        if cursor.peek().is_none() {
84            if let TokenTree::Delimited(_, _, token::Delimiter::Brace, _) = tt {
85                break tt;
86            } else {
87                return Err(ecx.sess.dcx().span_err(
88                    attr_span,
89                    "contract annotations is only supported in functions with bodies",
90                ));
91            }
92        }
93
94        if is_kw(tt, kw::Where) {
95            break tt;
96        }
97        new_tts.push_tree(tt.clone());
98    };
99
100    // At this point, we've transcribed everything from the `fn` through the formal parameter list
101    // and return type declaration, (if any), but `tt` itself has *not* been transcribed.
102    //
103    // Now inject the AST contract form.
104    //
105    inject(&mut new_tts)?;
106
107    // Above we injected the internal AST requires/ensures construct. Now copy over all the other
108    // token trees.
109    new_tts.push_tree(next_tt.clone());
110    while let Some(tt) = cursor.next() {
111        new_tts.push_tree(tt.clone());
112        if cursor.peek().is_none()
113            && !matches!(tt, TokenTree::Delimited(_, _, token::Delimiter::Brace, _))
114        {
115            return Err(ecx.sess.dcx().span_err(
116                attr_span,
117                "contract annotations is only supported in functions with bodies",
118            ));
119        }
120    }
121
122    Ok(new_tts)
123}
124
125fn expand_requires_tts(
126    ecx: &mut ExtCtxt<'_>,
127    attr_span: Span,
128    annotation: TokenStream,
129    annotated: TokenStream,
130) -> Result<TokenStream, ErrorGuaranteed> {
131    let feature_span = ecx.with_def_site_ctxt(attr_span);
132    expand_contract_clause(ecx, attr_span, annotated, |new_tts| {
133        new_tts.push_tree(TokenTree::Token(
134            token::Token::from_ast_ident(Ident::new(kw::ContractRequires, feature_span)),
135            Spacing::Joint,
136        ));
137        new_tts.push_tree(TokenTree::Token(
138            token::Token::new(token::TokenKind::OrOr, attr_span),
139            Spacing::Alone,
140        ));
141        new_tts.push_tree(TokenTree::Delimited(
142            DelimSpan::from_single(attr_span),
143            DelimSpacing::new(Spacing::JointHidden, Spacing::JointHidden),
144            token::Delimiter::Parenthesis,
145            annotation,
146        ));
147        Ok(())
148    })
149}
150
151fn expand_ensures_tts(
152    ecx: &mut ExtCtxt<'_>,
153    attr_span: Span,
154    annotation: TokenStream,
155    annotated: TokenStream,
156) -> Result<TokenStream, ErrorGuaranteed> {
157    let feature_span = ecx.with_def_site_ctxt(attr_span);
158    expand_contract_clause(ecx, attr_span, annotated, |new_tts| {
159        new_tts.push_tree(TokenTree::Token(
160            token::Token::from_ast_ident(Ident::new(kw::ContractEnsures, feature_span)),
161            Spacing::Joint,
162        ));
163        new_tts.push_tree(TokenTree::Delimited(
164            DelimSpan::from_single(attr_span),
165            DelimSpacing::new(Spacing::JointHidden, Spacing::JointHidden),
166            token::Delimiter::Parenthesis,
167            annotation,
168        ));
169        Ok(())
170    })
171}