1use rustc_ast::token;
2use rustc_ast::tokenstream::{DelimSpacing, DelimSpan, Spacing, TokenStream, TokenTree};
3use rustc_errors::ErrorGuaranteed;
4use rustc_expand::base::{AttrProcMacro, ExtCtxt};
5use rustc_span::Span;
6use rustc_span::symbol::{Ident, Symbol, kw};
7
/// Attribute expander for the `requires` contract clause.
///
/// Its `AttrProcMacro` impl forwards to `expand_requires_tts`.
pub(crate) struct ExpandRequires;
9
/// Attribute expander for the `ensures` contract clause.
///
/// Its `AttrProcMacro` impl forwards to `expand_ensures_tts`.
pub(crate) struct ExpandEnsures;
11
12impl AttrProcMacro for ExpandRequires {
13 fn expand<'cx>(
14 &self,
15 ecx: &'cx mut ExtCtxt<'_>,
16 span: Span,
17 annotation: TokenStream,
18 annotated: TokenStream,
19 ) -> Result<TokenStream, ErrorGuaranteed> {
20 expand_requires_tts(ecx, span, annotation, annotated)
21 }
22}
23
24impl AttrProcMacro for ExpandEnsures {
25 fn expand<'cx>(
26 &self,
27 ecx: &'cx mut ExtCtxt<'_>,
28 span: Span,
29 annotation: TokenStream,
30 annotated: TokenStream,
31 ) -> Result<TokenStream, ErrorGuaranteed> {
32 expand_ensures_tts(ecx, span, annotation, annotated)
33 }
34}
35
36fn expand_contract_clause(
48 ecx: &mut ExtCtxt<'_>,
49 attr_span: Span,
50 annotated: TokenStream,
51 inject: impl FnOnce(&mut TokenStream) -> Result<(), ErrorGuaranteed>,
52) -> Result<TokenStream, ErrorGuaranteed> {
53 let mut new_tts = TokenStream::default();
54 let mut cursor = annotated.iter();
55
56 let is_kw = |tt: &TokenTree, sym: Symbol| {
57 if let TokenTree::Token(token, _) = tt { token.is_ident_named(sym) } else { false }
58 };
59
60 if cursor
62 .find(|tt| {
63 new_tts.push_tree((*tt).clone());
64 is_kw(tt, kw::Fn)
65 })
66 .is_none()
67 {
68 return Err(ecx
69 .sess
70 .dcx()
71 .span_err(attr_span, "contract annotations can only be used on functions"));
72 }
73
74 if new_tts.iter().any(|tt| is_kw(tt, kw::Async) || is_kw(tt, kw::Gen)) {
76 return Err(ecx.sess.dcx().span_err(
77 attr_span,
78 "contract annotations are not yet supported on async or gen functions",
79 ));
80 }
81
82 let next_tt = loop {
84 let Some(tt) = cursor.next() else {
85 return Err(ecx.sess.dcx().span_err(
86 attr_span,
87 "contract annotations is only supported in functions with bodies",
88 ));
89 };
90 if cursor.peek().is_none() {
92 if let TokenTree::Delimited(_, _, token::Delimiter::Brace, _) = tt {
93 break tt;
94 } else {
95 return Err(ecx.sess.dcx().span_err(
96 attr_span,
97 "contract annotations is only supported in functions with bodies",
98 ));
99 }
100 }
101
102 if is_kw(tt, kw::Where) {
103 break tt;
104 }
105 new_tts.push_tree(tt.clone());
106 };
107
108 inject(&mut new_tts)?;
114
115 new_tts.push_tree(next_tt.clone());
118 while let Some(tt) = cursor.next() {
119 new_tts.push_tree(tt.clone());
120 if cursor.peek().is_none()
121 && !matches!(tt, TokenTree::Delimited(_, _, token::Delimiter::Brace, _))
122 {
123 return Err(ecx.sess.dcx().span_err(
124 attr_span,
125 "contract annotations is only supported in functions with bodies",
126 ));
127 }
128 }
129
130 Ok(new_tts)
131}
132
133fn expand_requires_tts(
134 ecx: &mut ExtCtxt<'_>,
135 attr_span: Span,
136 annotation: TokenStream,
137 annotated: TokenStream,
138) -> Result<TokenStream, ErrorGuaranteed> {
139 let feature_span = ecx.with_def_site_ctxt(attr_span);
140 expand_contract_clause(ecx, attr_span, annotated, |new_tts| {
141 new_tts.push_tree(TokenTree::Token(
142 token::Token::from_ast_ident(Ident::new(kw::ContractRequires, feature_span)),
143 Spacing::Joint,
144 ));
145 new_tts.push_tree(TokenTree::Token(
146 token::Token::new(token::TokenKind::OrOr, attr_span),
147 Spacing::Alone,
148 ));
149 new_tts.push_tree(TokenTree::Delimited(
150 DelimSpan::from_single(attr_span),
151 DelimSpacing::new(Spacing::JointHidden, Spacing::JointHidden),
152 token::Delimiter::Brace,
153 annotation,
154 ));
155 Ok(())
156 })
157}
158
159fn expand_ensures_tts(
160 ecx: &mut ExtCtxt<'_>,
161 attr_span: Span,
162 annotation: TokenStream,
163 annotated: TokenStream,
164) -> Result<TokenStream, ErrorGuaranteed> {
165 let feature_span = ecx.with_def_site_ctxt(attr_span);
166 expand_contract_clause(ecx, attr_span, annotated, |new_tts| {
167 new_tts.push_tree(TokenTree::Token(
168 token::Token::from_ast_ident(Ident::new(kw::ContractEnsures, feature_span)),
169 Spacing::Joint,
170 ));
171 new_tts.push_tree(TokenTree::Delimited(
172 DelimSpan::from_single(attr_span),
173 DelimSpacing::new(Spacing::JointHidden, Spacing::JointHidden),
174 token::Delimiter::Brace,
175 annotation,
176 ));
177 Ok(())
178 })
179}