rustc_builtin_macros/
autodiff.rs

1//! This module contains the implementation of the `#[autodiff]` attribute.
2//! Currently our linter isn't smart enough to see that each import is used in one of the two
3//! configs (autodiff enabled or disabled), so we have to add cfg's to each import.
4//! FIXME(ZuseZ4): Remove this once we have a smarter linter.
5
6mod llvm_enzyme {
7    use std::str::FromStr;
8    use std::string::String;
9
10    use rustc_ast::expand::autodiff_attrs::{
11        AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
12        valid_ty_for_activity,
13    };
14    use rustc_ast::ptr::P;
15    use rustc_ast::token::{Token, TokenKind};
16    use rustc_ast::tokenstream::*;
17    use rustc_ast::visit::AssocCtxt::*;
18    use rustc_ast::{
19        self as ast, AssocItemKind, BindingMode, FnRetTy, FnSig, Generics, ItemKind, MetaItemInner,
20        PatKind, TyKind,
21    };
22    use rustc_expand::base::{Annotatable, ExtCtxt};
23    use rustc_span::{Ident, Span, Symbol, kw, sym};
24    use thin_vec::{ThinVec, thin_vec};
25    use tracing::{debug, trace};
26
27    use crate::errors;
28
29    pub(crate) fn outer_normal_attr(
30        kind: &P<rustc_ast::NormalAttr>,
31        id: rustc_ast::AttrId,
32        span: Span,
33    ) -> rustc_ast::Attribute {
34        let style = rustc_ast::AttrStyle::Outer;
35        let kind = rustc_ast::AttrKind::Normal(kind.clone());
36        rustc_ast::Attribute { kind, id, style, span }
37    }
38
39    // If we have a default `()` return type or explicitley `()` return type,
40    // then we often can skip doing some work.
41    fn has_ret(ty: &FnRetTy) -> bool {
42        match ty {
43            FnRetTy::Ty(ty) => !ty.kind.is_unit(),
44            FnRetTy::Default(_) => false,
45        }
46    }
47    fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {
48        let segments = &x.meta_item().unwrap().path.segments;
49        assert!(segments.len() == 1);
50        segments[0].ident
51    }
52
53    fn name(x: &MetaItemInner) -> String {
54        first_ident(x).name.to_string()
55    }
56
57    pub(crate) fn from_ast(
58        ecx: &mut ExtCtxt<'_>,
59        meta_item: &ThinVec<MetaItemInner>,
60        has_ret: bool,
61    ) -> AutoDiffAttrs {
62        let dcx = ecx.sess.dcx();
63        let mode = name(&meta_item[1]);
64        let Ok(mode) = DiffMode::from_str(&mode) else {
65            dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
66            return AutoDiffAttrs::error();
67        };
68        let mut activities: Vec<DiffActivity> = vec![];
69        let mut errors = false;
70        for x in &meta_item[2..] {
71            let activity_str = name(&x);
72            let res = DiffActivity::from_str(&activity_str);
73            match res {
74                Ok(x) => activities.push(x),
75                Err(_) => {
76                    dcx.emit_err(errors::AutoDiffUnknownActivity {
77                        span: x.span(),
78                        act: activity_str,
79                    });
80                    errors = true;
81                }
82            };
83        }
84        if errors {
85            return AutoDiffAttrs::error();
86        }
87
88        // If a return type exist, we need to split the last activity,
89        // otherwise we return None as placeholder.
90        let (ret_activity, input_activity) = if has_ret {
91            let Some((last, rest)) = activities.split_last() else {
92                unreachable!(
93                    "should not be reachable because we counted the number of activities previously"
94                );
95            };
96            (last, rest)
97        } else {
98            (&DiffActivity::None, activities.as_slice())
99        };
100
101        AutoDiffAttrs { mode, ret_activity: *ret_activity, input_activity: input_activity.to_vec() }
102    }
103
104    /// We expand the autodiff macro to generate a new placeholder function which passes
105    /// type-checking and can be called by users. The function body of the placeholder function will
106    /// later be replaced on LLVM-IR level, so the design of the body is less important and for now
107    /// should just prevent early inlining and optimizations which alter the function signature.
108    /// The exact signature of the generated function depends on the configuration provided by the
109    /// user, but here is an example:
110    ///
111    /// ```
112    /// #[autodiff(cos_box, Reverse, Duplicated, Active)]
113    /// fn sin(x: &Box<f32>) -> f32 {
114    ///     f32::sin(**x)
115    /// }
116    /// ```
117    /// which becomes expanded to:
118    /// ```
119    /// #[rustc_autodiff]
120    /// #[inline(never)]
121    /// fn sin(x: &Box<f32>) -> f32 {
122    ///     f32::sin(**x)
123    /// }
124    /// #[rustc_autodiff(Reverse, Duplicated, Active)]
125    /// #[inline(never)]
126    /// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
127    ///     unsafe {
128    ///         asm!("NOP");
129    ///     };
130    ///     ::core::hint::black_box(sin(x));
131    ///     ::core::hint::black_box((dx, dret));
132    ///     ::core::hint::black_box(sin(x))
133    /// }
134    /// ```
135    /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
136    /// in CI.
137    pub(crate) fn expand(
138        ecx: &mut ExtCtxt<'_>,
139        expand_span: Span,
140        meta_item: &ast::MetaItem,
141        mut item: Annotatable,
142    ) -> Vec<Annotatable> {
143        if cfg!(not(llvm_enzyme)) {
144            ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
145            return vec![item];
146        }
147        let dcx = ecx.sess.dcx();
148        // first get the annotable item:
149        let (sig, is_impl): (FnSig, bool) = match &item {
150            Annotatable::Item(iitem) => {
151                let sig = match &iitem.kind {
152                    ItemKind::Fn(box ast::Fn { sig, .. }) => sig,
153                    _ => {
154                        dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
155                        return vec![item];
156                    }
157                };
158                (sig.clone(), false)
159            }
160            Annotatable::AssocItem(assoc_item, Impl { of_trait: false }) => {
161                let sig = match &assoc_item.kind {
162                    ast::AssocItemKind::Fn(box ast::Fn { sig, .. }) => sig,
163                    _ => {
164                        dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
165                        return vec![item];
166                    }
167                };
168                (sig.clone(), true)
169            }
170            _ => {
171                dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
172                return vec![item];
173            }
174        };
175
176        let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
177            ast::MetaItemKind::List(ref vec) => vec.clone(),
178            _ => {
179                dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
180                return vec![item];
181            }
182        };
183
184        let has_ret = has_ret(&sig.decl.output);
185        let sig_span = ecx.with_call_site_ctxt(sig.span);
186
187        let (vis, primal) = match &item {
188            Annotatable::Item(iitem) => (iitem.vis.clone(), iitem.ident.clone()),
189            Annotatable::AssocItem(assoc_item, _) => {
190                (assoc_item.vis.clone(), assoc_item.ident.clone())
191            }
192            _ => {
193                dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
194                return vec![item];
195            }
196        };
197
198        // create TokenStream from vec elemtents:
199        // meta_item doesn't have a .tokens field
200        let comma: Token = Token::new(TokenKind::Comma, Span::default());
201        let mut ts: Vec<TokenTree> = vec![];
202        if meta_item_vec.len() < 2 {
203            // At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
204            // input and output args.
205            dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
206            return vec![item];
207        } else {
208            for t in meta_item_vec.clone()[1..].iter() {
209                let val = first_ident(t);
210                let t = Token::from_ast_ident(val);
211                ts.push(TokenTree::Token(t, Spacing::Joint));
212                ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
213            }
214        }
215        if !has_ret {
216            // We don't want users to provide a return activity if the function doesn't return anything.
217            // For simplicity, we just add a dummy token to the end of the list.
218            let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
219            ts.push(TokenTree::Token(t, Spacing::Joint));
220        }
221        let ts: TokenStream = TokenStream::from_iter(ts);
222
223        let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
224        if !x.is_active() {
225            // We encountered an error, so we return the original item.
226            // This allows us to potentially parse other attributes.
227            return vec![item];
228        }
229        let span = ecx.with_def_site_ctxt(expand_span);
230
231        let n_active: u32 = x
232            .input_activity
233            .iter()
234            .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
235            .count() as u32;
236        let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
237        let d_body = gen_enzyme_body(
238            ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
239        );
240        let d_ident = first_ident(&meta_item_vec[0]);
241
242        // The first element of it is the name of the function to be generated
243        let asdf = Box::new(ast::Fn {
244            defaultness: ast::Defaultness::Final,
245            sig: d_sig,
246            generics: Generics::default(),
247            contract: None,
248            body: Some(d_body),
249            define_opaque: None,
250        });
251        let mut rustc_ad_attr =
252            P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
253
254        let ts2: Vec<TokenTree> = vec![TokenTree::Token(
255            Token::new(TokenKind::Ident(sym::never, false.into()), span),
256            Spacing::Joint,
257        )];
258        let never_arg = ast::DelimArgs {
259            dspan: ast::tokenstream::DelimSpan::from_single(span),
260            delim: ast::token::Delimiter::Parenthesis,
261            tokens: ast::tokenstream::TokenStream::from_iter(ts2),
262        };
263        let inline_item = ast::AttrItem {
264            unsafety: ast::Safety::Default,
265            path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),
266            args: ast::AttrArgs::Delimited(never_arg),
267            tokens: None,
268        };
269        let inline_never_attr = P(ast::NormalAttr { item: inline_item, tokens: None });
270        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
271        let attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
272        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
273        let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);
274
275        // We're avoid duplicating the attributes `#[rustc_autodiff]` and `#[inline(never)]`.
276        fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
277            match (attr, item) {
278                (ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {
279                    let a = &a.item.path;
280                    let b = &b.item.path;
281                    a.segments.len() == b.segments.len()
282                        && a.segments.iter().zip(b.segments.iter()).all(|(a, b)| a.ident == b.ident)
283                }
284                _ => false,
285            }
286        }
287
288        // Don't add it multiple times:
289        let orig_annotatable: Annotatable = match item {
290            Annotatable::Item(ref mut iitem) => {
291                if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
292                    iitem.attrs.push(attr);
293                }
294                if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
295                    iitem.attrs.push(inline_never.clone());
296                }
297                Annotatable::Item(iitem.clone())
298            }
299            Annotatable::AssocItem(ref mut assoc_item, i @ Impl { of_trait: false }) => {
300                if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
301                    assoc_item.attrs.push(attr);
302                }
303                if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
304                    assoc_item.attrs.push(inline_never.clone());
305                }
306                Annotatable::AssocItem(assoc_item.clone(), i)
307            }
308            _ => {
309                unreachable!("annotatable kind checked previously")
310            }
311        };
312        // Now update for d_fn
313        rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {
314            dspan: DelimSpan::dummy(),
315            delim: rustc_ast::token::Delimiter::Parenthesis,
316            tokens: ts,
317        });
318        let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
319        let d_annotatable = if is_impl {
320            let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
321            let d_fn = P(ast::AssocItem {
322                attrs: thin_vec![d_attr, inline_never],
323                id: ast::DUMMY_NODE_ID,
324                span,
325                vis,
326                ident: d_ident,
327                kind: assoc_item,
328                tokens: None,
329            });
330            Annotatable::AssocItem(d_fn, Impl { of_trait: false })
331        } else {
332            let mut d_fn =
333                ecx.item(span, d_ident, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
334            d_fn.vis = vis;
335            Annotatable::Item(d_fn)
336        };
337
338        return vec![orig_annotatable, d_annotatable];
339    }
340
341    // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
342    // mutable references or ptrs, because Enzyme will write into them.
343    fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
344        let mut ty = ty.clone();
345        match ty.kind {
346            TyKind::Ptr(ref mut mut_ty) => {
347                mut_ty.mutbl = ast::Mutability::Mut;
348            }
349            TyKind::Ref(_, ref mut mut_ty) => {
350                mut_ty.mutbl = ast::Mutability::Mut;
351            }
352            _ => {
353                panic!("unsupported type: {:?}", ty);
354            }
355        }
356        ty
357    }
358
359    // Will generate a body of the type:
360    // ```
361    // {
362    //   unsafe {
363    //   asm!("NOP");
364    //   }
365    //   ::core::hint::black_box(primal(args));
366    //   ::core::hint::black_box((args, ret));
367    //   <This part remains to be done by following function>
368    // }
369    // ```
370    fn init_body_helper(
371        ecx: &ExtCtxt<'_>,
372        span: Span,
373        primal: Ident,
374        new_names: &[String],
375        sig_span: Span,
376        new_decl_span: Span,
377        idents: &[Ident],
378        errored: bool,
379    ) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) {
380        let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
381        let noop = ast::InlineAsm {
382            asm_macro: ast::AsmMacro::Asm,
383            template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())],
384            template_strs: Box::new([]),
385            operands: vec![],
386            clobber_abis: vec![],
387            options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM,
388            line_spans: vec![],
389        };
390        let noop_expr = ecx.expr_asm(span, P(noop));
391        let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated);
392        let unsf_block = ast::Block {
393            stmts: thin_vec![ecx.stmt_semi(noop_expr)],
394            id: ast::DUMMY_NODE_ID,
395            tokens: None,
396            rules: unsf,
397            span,
398        };
399        let unsf_expr = ecx.expr_block(P(unsf_block));
400        let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
401        let primal_call = gen_primal_call(ecx, span, primal, idents);
402        let black_box_primal_call = ecx.expr_call(
403            new_decl_span,
404            blackbox_call_expr.clone(),
405            thin_vec![primal_call.clone()],
406        );
407        let tup_args = new_names
408            .iter()
409            .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg))))
410            .collect();
411
412        let black_box_remaining_args = ecx.expr_call(
413            sig_span,
414            blackbox_call_expr.clone(),
415            thin_vec![ecx.expr_tuple(sig_span, tup_args)],
416        );
417
418        let mut body = ecx.block(span, ThinVec::new());
419        body.stmts.push(ecx.stmt_semi(unsf_expr));
420
421        // This uses primal args which won't be available if we errored before
422        if !errored {
423            body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone()));
424        }
425        body.stmts.push(ecx.stmt_semi(black_box_remaining_args));
426
427        (body, primal_call, black_box_primal_call, blackbox_call_expr)
428    }
429
430    /// We only want this function to type-check, since we will replace the body
431    /// later on llvm level. Using `loop {}` does not cover all return types anymore,
432    /// so instead we manually build something that should pass the type checker.
433    /// We also add a inline_asm line, as one more barrier for rustc to prevent inlining
434    /// or const propagation. inline_asm will also triggers an Enzyme crash if due to another
435    /// bug would ever try to accidentially differentiate this placeholder function body.
436    /// Finally, we also add back_box usages of all input arguments, to prevent rustc
437    /// from optimizing any arguments away.
438    fn gen_enzyme_body(
439        ecx: &ExtCtxt<'_>,
440        x: &AutoDiffAttrs,
441        n_active: u32,
442        sig: &ast::FnSig,
443        d_sig: &ast::FnSig,
444        primal: Ident,
445        new_names: &[String],
446        span: Span,
447        sig_span: Span,
448        idents: Vec<Ident>,
449        errored: bool,
450    ) -> P<ast::Block> {
451        let new_decl_span = d_sig.span;
452
453        // Just adding some default inline-asm and black_box usages to prevent early inlining
454        // and optimizations which alter the function signature.
455        //
456        // The bb_primal_call is the black_box call of the primal function. We keep it around,
457        // since it has the convenient property of returning the type of the primal function,
458        // Remember, we only care to match types here.
459        // No matter which return we pick, we always wrap it into a std::hint::black_box call,
460        // to prevent rustc from propagating it into the caller.
461        let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper(
462            ecx,
463            span,
464            primal,
465            new_names,
466            sig_span,
467            new_decl_span,
468            &idents,
469            errored,
470        );
471
472        if !has_ret(&d_sig.decl.output) {
473            // there is no return type that we have to match, () works fine.
474            return body;
475        }
476
477        // having an active-only return means we'll drop the original return type.
478        // So that can be treated identical to not having one in the first place.
479        let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
480
481        if primal_ret && n_active == 0 && x.mode.is_rev() {
482            // We only have the primal ret.
483            body.stmts.push(ecx.stmt_expr(bb_primal_call));
484            return body;
485        }
486
487        if !primal_ret && n_active == 1 {
488            // Again no tuple return, so return default float val.
489            let ty = match d_sig.decl.output {
490                FnRetTy::Ty(ref ty) => ty.clone(),
491                FnRetTy::Default(span) => {
492                    panic!("Did not expect Default ret ty: {:?}", span);
493                }
494            };
495            let arg = ty.kind.is_simple_path().unwrap();
496            let sl: Vec<Symbol> = vec![arg, kw::Default];
497            let tmp = ecx.def_site_path(&sl);
498            let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
499            let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
500            body.stmts.push(ecx.stmt_expr(default_call_expr));
501            return body;
502        }
503
504        let mut exprs = ThinVec::<P<ast::Expr>>::new();
505        if primal_ret {
506            // We have both primal ret and active floats.
507            // primal ret is first, by construction.
508            exprs.push(primal_call);
509        }
510
511        // Now construct default placeholder for each active float.
512        // Is there something nicer than f32::default() and f64::default()?
513        let d_ret_ty = match d_sig.decl.output {
514            FnRetTy::Ty(ref ty) => ty.clone(),
515            FnRetTy::Default(span) => {
516                panic!("Did not expect Default ret ty: {:?}", span);
517            }
518        };
519        let mut d_ret_ty = match d_ret_ty.kind.clone() {
520            TyKind::Tup(ref tys) => tys.clone(),
521            TyKind::Path(_, rustc_ast::Path { segments, .. }) => {
522                if let [segment] = &segments[..]
523                    && segment.args.is_none()
524                {
525                    let id = vec![segments[0].ident];
526                    let kind = TyKind::Path(None, ecx.path(span, id));
527                    let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
528                    thin_vec![ty]
529                } else {
530                    panic!("Expected tuple or simple path return type");
531                }
532            }
533            _ => {
534                // We messed up construction of d_sig
535                panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty);
536            }
537        };
538
539        if x.mode.is_fwd() && x.ret_activity == DiffActivity::Dual {
540            assert!(d_ret_ty.len() == 2);
541            // both should be identical, by construction
542            let arg = d_ret_ty[0].kind.is_simple_path().unwrap();
543            let arg2 = d_ret_ty[1].kind.is_simple_path().unwrap();
544            assert!(arg == arg2);
545            let sl: Vec<Symbol> = vec![arg, kw::Default];
546            let tmp = ecx.def_site_path(&sl);
547            let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
548            let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
549            exprs.push(default_call_expr);
550        } else if x.mode.is_rev() {
551            if primal_ret {
552                // We have extra handling above for the primal ret
553                d_ret_ty = d_ret_ty[1..].to_vec().into();
554            }
555
556            for arg in d_ret_ty.iter() {
557                let arg = arg.kind.is_simple_path().unwrap();
558                let sl: Vec<Symbol> = vec![arg, kw::Default];
559                let tmp = ecx.def_site_path(&sl);
560                let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
561                let default_call_expr =
562                    ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
563                exprs.push(default_call_expr);
564            }
565        }
566
567        let ret: P<ast::Expr>;
568        match &exprs[..] {
569            [] => {
570                assert!(!has_ret(&d_sig.decl.output));
571                // We don't have to match the return type.
572                return body;
573            }
574            [arg] => {
575                ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![arg.clone()]);
576            }
577            args => {
578                let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, args.into());
579                ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![ret_tuple]);
580            }
581        }
582        assert!(has_ret(&d_sig.decl.output));
583        body.stmts.push(ecx.stmt_expr(ret));
584
585        body
586    }
587
588    fn gen_primal_call(
589        ecx: &ExtCtxt<'_>,
590        span: Span,
591        primal: Ident,
592        idents: &[Ident],
593    ) -> P<ast::Expr> {
594        let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
595        if has_self {
596            let args: ThinVec<_> =
597                idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
598            let self_expr = ecx.expr_self(span);
599            ecx.expr_method_call(span, self_expr, primal, args)
600        } else {
601            let args: ThinVec<_> =
602                idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
603            let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
604            ecx.expr_call(span, primal_call_expr, args)
605        }
606    }
607
608    // Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
609    // be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
610    // Active arguments must be scalars. Their shadow argument is added to the return type (and will be
611    // zero-initialized by Enzyme).
612    // Each argument of the primal function (and the return type if existing) must be annotated with an
613    // activity.
614    //
615    // Error handling: If the user provides an invalid configuration (incorrect numbers, types, or
616    // both), we emit an error and return the original signature. This allows us to continue parsing.
617    // FIXME(Sa4dUs): make individual activities' span available so errors
618    // can point to only the activity instead of the entire attribute
619    fn gen_enzyme_decl(
620        ecx: &ExtCtxt<'_>,
621        sig: &ast::FnSig,
622        x: &AutoDiffAttrs,
623        span: Span,
624    ) -> (ast::FnSig, Vec<String>, Vec<Ident>, bool) {
625        let dcx = ecx.sess.dcx();
626        let has_ret = has_ret(&sig.decl.output);
627        let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
628        let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 };
629        if sig_args != num_activities {
630            dcx.emit_err(errors::AutoDiffInvalidNumberActivities {
631                span,
632                expected: sig_args,
633                found: num_activities,
634            });
635            // This is not the right signature, but we can continue parsing.
636            return (sig.clone(), vec![], vec![], true);
637        }
638        assert!(sig.decl.inputs.len() == x.input_activity.len());
639        assert!(has_ret == x.has_ret_activity());
640        let mut d_decl = sig.decl.clone();
641        let mut d_inputs = Vec::new();
642        let mut new_inputs = Vec::new();
643        let mut idents = Vec::new();
644        let mut act_ret = ThinVec::new();
645
646        // We have two loops, a first one just to check the activities and types and possibly report
647        // multiple errors in one compilation session.
648        let mut errors = false;
649        for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
650            if !valid_input_activity(x.mode, *activity) {
651                dcx.emit_err(errors::AutoDiffInvalidApplicationModeAct {
652                    span,
653                    mode: x.mode.to_string(),
654                    act: activity.to_string(),
655                });
656                errors = true;
657            }
658            if !valid_ty_for_activity(&arg.ty, *activity) {
659                dcx.emit_err(errors::AutoDiffInvalidTypeForActivity {
660                    span: arg.ty.span,
661                    act: activity.to_string(),
662                });
663                errors = true;
664            }
665        }
666
667        if has_ret && !valid_ret_activity(x.mode, x.ret_activity) {
668            dcx.emit_err(errors::AutoDiffInvalidRetAct {
669                span,
670                mode: x.mode.to_string(),
671                act: x.ret_activity.to_string(),
672            });
673            // We don't set `errors = true` to avoid annoying type errors relative
674            // to the expanded macro type signature
675        }
676
677        if errors {
678            // This is not the right signature, but we can continue parsing.
679            return (sig.clone(), new_inputs, idents, true);
680        }
681
682        let unsafe_activities = x
683            .input_activity
684            .iter()
685            .any(|&act| matches!(act, DiffActivity::DuplicatedOnly | DiffActivity::DualOnly));
686        for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
687            d_inputs.push(arg.clone());
688            match activity {
689                DiffActivity::Active => {
690                    act_ret.push(arg.ty.clone());
691                }
692                DiffActivity::ActiveOnly => {
693                    // We will add the active scalar to the return type.
694                    // This is handled later.
695                }
696                DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
697                    let mut shadow_arg = arg.clone();
698                    // We += into the shadow in reverse mode.
699                    shadow_arg.ty = P(assure_mut_ref(&arg.ty));
700                    let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
701                        ident.name
702                    } else {
703                        debug!("{:#?}", &shadow_arg.pat);
704                        panic!("not an ident?");
705                    };
706                    let name: String = format!("d{}", old_name);
707                    new_inputs.push(name.clone());
708                    let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
709                    shadow_arg.pat = P(ast::Pat {
710                        id: ast::DUMMY_NODE_ID,
711                        kind: PatKind::Ident(BindingMode::NONE, ident, None),
712                        span: shadow_arg.pat.span,
713                        tokens: shadow_arg.pat.tokens.clone(),
714                    });
715                    d_inputs.push(shadow_arg);
716                }
717                DiffActivity::Dual | DiffActivity::DualOnly => {
718                    let mut shadow_arg = arg.clone();
719                    let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
720                        ident.name
721                    } else {
722                        debug!("{:#?}", &shadow_arg.pat);
723                        panic!("not an ident?");
724                    };
725                    let name: String = format!("b{}", old_name);
726                    new_inputs.push(name.clone());
727                    let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
728                    shadow_arg.pat = P(ast::Pat {
729                        id: ast::DUMMY_NODE_ID,
730                        kind: PatKind::Ident(BindingMode::NONE, ident, None),
731                        span: shadow_arg.pat.span,
732                        tokens: shadow_arg.pat.tokens.clone(),
733                    });
734                    d_inputs.push(shadow_arg);
735                }
736                DiffActivity::Const => {
737                    // Nothing to do here.
738                }
739                DiffActivity::None | DiffActivity::FakeActivitySize => {
740                    panic!("Should not happen");
741                }
742            }
743            if let PatKind::Ident(_, ident, _) = arg.pat.kind {
744                idents.push(ident.clone());
745            } else {
746                panic!("not an ident?");
747            }
748        }
749
750        let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;
751        if active_only_ret {
752            assert!(x.mode.is_rev());
753        }
754
755        // If we return a scalar in the primal and the scalar is active,
756        // then add it as last arg to the inputs.
757        if x.mode.is_rev() {
758            match x.ret_activity {
759                DiffActivity::Active | DiffActivity::ActiveOnly => {
760                    let ty = match d_decl.output {
761                        FnRetTy::Ty(ref ty) => ty.clone(),
762                        FnRetTy::Default(span) => {
763                            panic!("Did not expect Default ret ty: {:?}", span);
764                        }
765                    };
766                    let name = "dret".to_string();
767                    let ident = Ident::from_str_and_span(&name, ty.span);
768                    let shadow_arg = ast::Param {
769                        attrs: ThinVec::new(),
770                        ty: ty.clone(),
771                        pat: P(ast::Pat {
772                            id: ast::DUMMY_NODE_ID,
773                            kind: PatKind::Ident(BindingMode::NONE, ident, None),
774                            span: ty.span,
775                            tokens: None,
776                        }),
777                        id: ast::DUMMY_NODE_ID,
778                        span: ty.span,
779                        is_placeholder: false,
780                    };
781                    d_inputs.push(shadow_arg);
782                    new_inputs.push(name);
783                }
784                _ => {}
785            }
786        }
787        d_decl.inputs = d_inputs.into();
788
789        if x.mode.is_fwd() {
790            if let DiffActivity::Dual = x.ret_activity {
791                let ty = match d_decl.output {
792                    FnRetTy::Ty(ref ty) => ty.clone(),
793                    FnRetTy::Default(span) => {
794                        panic!("Did not expect Default ret ty: {:?}", span);
795                    }
796                };
797                // Dual can only be used for f32/f64 ret.
798                // In that case we return now a tuple with two floats.
799                let kind = TyKind::Tup(thin_vec![ty.clone(), ty.clone()]);
800                let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
801                d_decl.output = FnRetTy::Ty(ty);
802            }
803            if let DiffActivity::DualOnly = x.ret_activity {
804                // No need to change the return type,
805                // we will just return the shadow in place
806                // of the primal return.
807            }
808        }
809
810        // If we use ActiveOnly, drop the original return value.
811        d_decl.output =
812            if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() };
813
814        trace!("act_ret: {:?}", act_ret);
815
816        // If we have an active input scalar, add it's gradient to the
817        // return type. This might require changing the return type to a
818        // tuple.
819        if act_ret.len() > 0 {
820            let ret_ty = match d_decl.output {
821                FnRetTy::Ty(ref ty) => {
822                    if !active_only_ret {
823                        act_ret.insert(0, ty.clone());
824                    }
825                    let kind = TyKind::Tup(act_ret);
826                    P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None })
827                }
828                FnRetTy::Default(span) => {
829                    if act_ret.len() == 1 {
830                        act_ret[0].clone()
831                    } else {
832                        let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect());
833                        P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None })
834                    }
835                }
836            };
837            d_decl.output = FnRetTy::Ty(ret_ty);
838        }
839
840        let mut d_header = sig.header.clone();
841        if unsafe_activities {
842            d_header.safety = rustc_ast::Safety::Unsafe(span);
843        }
844        let d_sig = FnSig { header: d_header, decl: d_decl, span };
845        trace!("Generated signature: {:?}", d_sig);
846        (d_sig, new_inputs, idents, false)
847    }
848}
849
850pub(crate) use llvm_enzyme::expand;