rustc_builtin_macros/
autodiff.rs

1//! This module contains the implementation of the `#[autodiff]` attribute.
2//! Currently our linter isn't smart enough to see that each import is used in one of the two
3//! configs (autodiff enabled or disabled), so we have to add cfg's to each import.
4//! FIXME(ZuseZ4): Remove this once we have a smarter linter.
5
6#[cfg(llvm_enzyme)]
7mod llvm_enzyme {
8    use std::str::FromStr;
9    use std::string::String;
10
11    use rustc_ast::expand::autodiff_attrs::{
12        AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ty_for_activity,
13    };
14    use rustc_ast::ptr::P;
15    use rustc_ast::token::{Token, TokenKind};
16    use rustc_ast::tokenstream::*;
17    use rustc_ast::visit::AssocCtxt::*;
18    use rustc_ast::{
19        self as ast, AssocItemKind, BindingMode, FnRetTy, FnSig, Generics, ItemKind, MetaItemInner,
20        PatKind, TyKind,
21    };
22    use rustc_expand::base::{Annotatable, ExtCtxt};
23    use rustc_span::{Ident, Span, Symbol, kw, sym};
24    use thin_vec::{ThinVec, thin_vec};
25    use tracing::{debug, trace};
26
27    use crate::errors;
28
29    // If we have a default `()` return type or explicitley `()` return type,
30    // then we often can skip doing some work.
31    fn has_ret(ty: &FnRetTy) -> bool {
32        match ty {
33            FnRetTy::Ty(ty) => !ty.kind.is_unit(),
34            FnRetTy::Default(_) => false,
35        }
36    }
37    fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {
38        let segments = &x.meta_item().unwrap().path.segments;
39        assert!(segments.len() == 1);
40        segments[0].ident
41    }
42
43    fn name(x: &MetaItemInner) -> String {
44        first_ident(x).name.to_string()
45    }
46
47    pub(crate) fn from_ast(
48        ecx: &mut ExtCtxt<'_>,
49        meta_item: &ThinVec<MetaItemInner>,
50        has_ret: bool,
51    ) -> AutoDiffAttrs {
52        let dcx = ecx.sess.dcx();
53        let mode = name(&meta_item[1]);
54        let Ok(mode) = DiffMode::from_str(&mode) else {
55            dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
56            return AutoDiffAttrs::error();
57        };
58        let mut activities: Vec<DiffActivity> = vec![];
59        let mut errors = false;
60        for x in &meta_item[2..] {
61            let activity_str = name(&x);
62            let res = DiffActivity::from_str(&activity_str);
63            match res {
64                Ok(x) => activities.push(x),
65                Err(_) => {
66                    dcx.emit_err(errors::AutoDiffUnknownActivity {
67                        span: x.span(),
68                        act: activity_str,
69                    });
70                    errors = true;
71                }
72            };
73        }
74        if errors {
75            return AutoDiffAttrs::error();
76        }
77
78        // If a return type exist, we need to split the last activity,
79        // otherwise we return None as placeholder.
80        let (ret_activity, input_activity) = if has_ret {
81            let Some((last, rest)) = activities.split_last() else {
82                unreachable!(
83                    "should not be reachable because we counted the number of activities previously"
84                );
85            };
86            (last, rest)
87        } else {
88            (&DiffActivity::None, activities.as_slice())
89        };
90
91        AutoDiffAttrs { mode, ret_activity: *ret_activity, input_activity: input_activity.to_vec() }
92    }
93
94    /// We expand the autodiff macro to generate a new placeholder function which passes
95    /// type-checking and can be called by users. The function body of the placeholder function will
96    /// later be replaced on LLVM-IR level, so the design of the body is less important and for now
97    /// should just prevent early inlining and optimizations which alter the function signature.
98    /// The exact signature of the generated function depends on the configuration provided by the
99    /// user, but here is an example:
100    ///
101    /// ```
102    /// #[autodiff(cos_box, Reverse, Duplicated, Active)]
103    /// fn sin(x: &Box<f32>) -> f32 {
104    ///     f32::sin(**x)
105    /// }
106    /// ```
107    /// which becomes expanded to:
108    /// ```
109    /// #[rustc_autodiff]
110    /// #[inline(never)]
111    /// fn sin(x: &Box<f32>) -> f32 {
112    ///     f32::sin(**x)
113    /// }
114    /// #[rustc_autodiff(Reverse, Duplicated, Active)]
115    /// #[inline(never)]
116    /// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
117    ///     unsafe {
118    ///         asm!("NOP");
119    ///     };
120    ///     ::core::hint::black_box(sin(x));
121    ///     ::core::hint::black_box((dx, dret));
122    ///     ::core::hint::black_box(sin(x))
123    /// }
124    /// ```
125    /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
126    /// in CI.
127    pub(crate) fn expand(
128        ecx: &mut ExtCtxt<'_>,
129        expand_span: Span,
130        meta_item: &ast::MetaItem,
131        mut item: Annotatable,
132    ) -> Vec<Annotatable> {
133        let dcx = ecx.sess.dcx();
134        // first get the annotable item:
135        let (sig, is_impl): (FnSig, bool) = match &item {
136            Annotatable::Item(ref iitem) => {
137                let sig = match &iitem.kind {
138                    ItemKind::Fn(box ast::Fn { sig, .. }) => sig,
139                    _ => {
140                        dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
141                        return vec![item];
142                    }
143                };
144                (sig.clone(), false)
145            }
146            Annotatable::AssocItem(ref assoc_item, _) => {
147                let sig = match &assoc_item.kind {
148                    ast::AssocItemKind::Fn(box ast::Fn { sig, .. }) => sig,
149                    _ => {
150                        dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
151                        return vec![item];
152                    }
153                };
154                (sig.clone(), true)
155            }
156            _ => {
157                dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
158                return vec![item];
159            }
160        };
161
162        let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
163            ast::MetaItemKind::List(ref vec) => vec.clone(),
164            _ => {
165                dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
166                return vec![item];
167            }
168        };
169
170        let has_ret = has_ret(&sig.decl.output);
171        let sig_span = ecx.with_call_site_ctxt(sig.span);
172
173        let (vis, primal) = match &item {
174            Annotatable::Item(ref iitem) => (iitem.vis.clone(), iitem.ident.clone()),
175            Annotatable::AssocItem(ref assoc_item, _) => {
176                (assoc_item.vis.clone(), assoc_item.ident.clone())
177            }
178            _ => {
179                dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
180                return vec![item];
181            }
182        };
183
184        // create TokenStream from vec elemtents:
185        // meta_item doesn't have a .tokens field
186        let comma: Token = Token::new(TokenKind::Comma, Span::default());
187        let mut ts: Vec<TokenTree> = vec![];
188        if meta_item_vec.len() < 2 {
189            // At the bare minimum, we need a fnc name and a mode, even for a dummy function with no
190            // input and output args.
191            dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
192            return vec![item];
193        } else {
194            for t in meta_item_vec.clone()[1..].iter() {
195                let val = first_ident(t);
196                let t = Token::from_ast_ident(val);
197                ts.push(TokenTree::Token(t, Spacing::Joint));
198                ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
199            }
200        }
201        if !has_ret {
202            // We don't want users to provide a return activity if the function doesn't return anything.
203            // For simplicity, we just add a dummy token to the end of the list.
204            let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
205            ts.push(TokenTree::Token(t, Spacing::Joint));
206        }
207        let ts: TokenStream = TokenStream::from_iter(ts);
208
209        let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
210        if !x.is_active() {
211            // We encountered an error, so we return the original item.
212            // This allows us to potentially parse other attributes.
213            return vec![item];
214        }
215        let span = ecx.with_def_site_ctxt(expand_span);
216
217        let n_active: u32 = x
218            .input_activity
219            .iter()
220            .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
221            .count() as u32;
222        let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
223        let new_decl_span = d_sig.span;
224        let d_body = gen_enzyme_body(
225            ecx,
226            &x,
227            n_active,
228            &sig,
229            &d_sig,
230            primal,
231            &new_args,
232            span,
233            sig_span,
234            new_decl_span,
235            idents,
236            errored,
237        );
238        let d_ident = first_ident(&meta_item_vec[0]);
239
240        // The first element of it is the name of the function to be generated
241        let asdf = Box::new(ast::Fn {
242            defaultness: ast::Defaultness::Final,
243            sig: d_sig,
244            generics: Generics::default(),
245            body: Some(d_body),
246        });
247        let mut rustc_ad_attr =
248            P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
249
250        let ts2: Vec<TokenTree> = vec![TokenTree::Token(
251            Token::new(TokenKind::Ident(sym::never, false.into()), span),
252            Spacing::Joint,
253        )];
254        let never_arg = ast::DelimArgs {
255            dspan: ast::tokenstream::DelimSpan::from_single(span),
256            delim: ast::token::Delimiter::Parenthesis,
257            tokens: ast::tokenstream::TokenStream::from_iter(ts2),
258        };
259        let inline_item = ast::AttrItem {
260            unsafety: ast::Safety::Default,
261            path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),
262            args: ast::AttrArgs::Delimited(never_arg),
263            tokens: None,
264        };
265        let inline_never_attr = P(ast::NormalAttr { item: inline_item, tokens: None });
266        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
267        let attr: ast::Attribute = ast::Attribute {
268            kind: ast::AttrKind::Normal(rustc_ad_attr.clone()),
269            id: new_id,
270            style: ast::AttrStyle::Outer,
271            span,
272        };
273        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
274        let inline_never: ast::Attribute = ast::Attribute {
275            kind: ast::AttrKind::Normal(inline_never_attr),
276            id: new_id,
277            style: ast::AttrStyle::Outer,
278            span,
279        };
280
281        // Don't add it multiple times:
282        let orig_annotatable: Annotatable = match item {
283            Annotatable::Item(ref mut iitem) => {
284                if !iitem.attrs.iter().any(|a| a.id == attr.id) {
285                    iitem.attrs.push(attr.clone());
286                }
287                if !iitem.attrs.iter().any(|a| a.id == inline_never.id) {
288                    iitem.attrs.push(inline_never.clone());
289                }
290                Annotatable::Item(iitem.clone())
291            }
292            Annotatable::AssocItem(ref mut assoc_item, i @ Impl) => {
293                if !assoc_item.attrs.iter().any(|a| a.id == attr.id) {
294                    assoc_item.attrs.push(attr.clone());
295                }
296                if !assoc_item.attrs.iter().any(|a| a.id == inline_never.id) {
297                    assoc_item.attrs.push(inline_never.clone());
298                }
299                Annotatable::AssocItem(assoc_item.clone(), i)
300            }
301            _ => {
302                unreachable!("annotatable kind checked previously")
303            }
304        };
305        // Now update for d_fn
306        rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {
307            dspan: DelimSpan::dummy(),
308            delim: rustc_ast::token::Delimiter::Parenthesis,
309            tokens: ts,
310        });
311        let d_attr: ast::Attribute = ast::Attribute {
312            kind: ast::AttrKind::Normal(rustc_ad_attr.clone()),
313            id: new_id,
314            style: ast::AttrStyle::Outer,
315            span,
316        };
317
318        let d_annotatable = if is_impl {
319            let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
320            let d_fn = P(ast::AssocItem {
321                attrs: thin_vec![d_attr.clone(), inline_never],
322                id: ast::DUMMY_NODE_ID,
323                span,
324                vis,
325                ident: d_ident,
326                kind: assoc_item,
327                tokens: None,
328            });
329            Annotatable::AssocItem(d_fn, Impl)
330        } else {
331            let mut d_fn = ecx.item(
332                span,
333                d_ident,
334                thin_vec![d_attr.clone(), inline_never],
335                ItemKind::Fn(asdf),
336            );
337            d_fn.vis = vis;
338            Annotatable::Item(d_fn)
339        };
340
341        return vec![orig_annotatable, d_annotatable];
342    }
343
344    // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
345    // mutable references or ptrs, because Enzyme will write into them.
346    fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
347        let mut ty = ty.clone();
348        match ty.kind {
349            TyKind::Ptr(ref mut mut_ty) => {
350                mut_ty.mutbl = ast::Mutability::Mut;
351            }
352            TyKind::Ref(_, ref mut mut_ty) => {
353                mut_ty.mutbl = ast::Mutability::Mut;
354            }
355            _ => {
356                panic!("unsupported type: {:?}", ty);
357            }
358        }
359        ty
360    }
361
362    /// We only want this function to type-check, since we will replace the body
363    /// later on llvm level. Using `loop {}` does not cover all return types anymore,
364    /// so instead we build something that should pass. We also add a inline_asm
365    /// line, as one more barrier for rustc to prevent inlining of this function.
366    /// FIXME(ZuseZ4): We still have cases of incorrect inlining across modules, see
367    /// <https://github.com/EnzymeAD/rust/issues/173>, so this isn't sufficient.
368    /// It also triggers an Enzyme crash if we due to a bug ever try to differentiate
369    /// this function (which should never happen, since it is only a placeholder).
370    /// Finally, we also add back_box usages of all input arguments, to prevent rustc
371    /// from optimizing any arguments away.
372    fn gen_enzyme_body(
373        ecx: &ExtCtxt<'_>,
374        x: &AutoDiffAttrs,
375        n_active: u32,
376        sig: &ast::FnSig,
377        d_sig: &ast::FnSig,
378        primal: Ident,
379        new_names: &[String],
380        span: Span,
381        sig_span: Span,
382        new_decl_span: Span,
383        idents: Vec<Ident>,
384        errored: bool,
385    ) -> P<ast::Block> {
386        let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
387        let noop = ast::InlineAsm {
388            asm_macro: ast::AsmMacro::Asm,
389            template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())],
390            template_strs: Box::new([]),
391            operands: vec![],
392            clobber_abis: vec![],
393            options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM,
394            line_spans: vec![],
395        };
396        let noop_expr = ecx.expr_asm(span, P(noop));
397        let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated);
398        let unsf_block = ast::Block {
399            stmts: thin_vec![ecx.stmt_semi(noop_expr)],
400            id: ast::DUMMY_NODE_ID,
401            tokens: None,
402            rules: unsf,
403            span,
404            could_be_bare_literal: false,
405        };
406        let unsf_expr = ecx.expr_block(P(unsf_block));
407        let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
408        let primal_call = gen_primal_call(ecx, span, primal, idents);
409        let black_box_primal_call = ecx.expr_call(
410            new_decl_span,
411            blackbox_call_expr.clone(),
412            thin_vec![primal_call.clone()],
413        );
414        let tup_args = new_names
415            .iter()
416            .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg))))
417            .collect();
418
419        let black_box_remaining_args = ecx.expr_call(
420            sig_span,
421            blackbox_call_expr.clone(),
422            thin_vec![ecx.expr_tuple(sig_span, tup_args)],
423        );
424
425        let mut body = ecx.block(span, ThinVec::new());
426        body.stmts.push(ecx.stmt_semi(unsf_expr));
427
428        // This uses primal args which won't be available if we errored before
429        if !errored {
430            body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone()));
431        }
432        body.stmts.push(ecx.stmt_semi(black_box_remaining_args));
433
434        if !has_ret(&d_sig.decl.output) {
435            // there is no return type that we have to match, () works fine.
436            return body;
437        }
438
439        // having an active-only return means we'll drop the original return type.
440        // So that can be treated identical to not having one in the first place.
441        let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
442
443        if primal_ret && n_active == 0 && x.mode.is_rev() {
444            // We only have the primal ret.
445            body.stmts.push(ecx.stmt_expr(black_box_primal_call.clone()));
446            return body;
447        }
448
449        if !primal_ret && n_active == 1 {
450            // Again no tuple return, so return default float val.
451            let ty = match d_sig.decl.output {
452                FnRetTy::Ty(ref ty) => ty.clone(),
453                FnRetTy::Default(span) => {
454                    panic!("Did not expect Default ret ty: {:?}", span);
455                }
456            };
457            let arg = ty.kind.is_simple_path().unwrap();
458            let sl: Vec<Symbol> = vec![arg, kw::Default];
459            let tmp = ecx.def_site_path(&sl);
460            let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
461            let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
462            body.stmts.push(ecx.stmt_expr(default_call_expr));
463            return body;
464        }
465
466        let mut exprs = ThinVec::<P<ast::Expr>>::new();
467        if primal_ret {
468            // We have both primal ret and active floats.
469            // primal ret is first, by construction.
470            exprs.push(primal_call.clone());
471        }
472
473        // Now construct default placeholder for each active float.
474        // Is there something nicer than f32::default() and f64::default()?
475        let d_ret_ty = match d_sig.decl.output {
476            FnRetTy::Ty(ref ty) => ty.clone(),
477            FnRetTy::Default(span) => {
478                panic!("Did not expect Default ret ty: {:?}", span);
479            }
480        };
481        let mut d_ret_ty = match d_ret_ty.kind.clone() {
482            TyKind::Tup(ref tys) => tys.clone(),
483            TyKind::Path(_, rustc_ast::Path { segments, .. }) => {
484                if let [segment] = &segments[..]
485                    && segment.args.is_none()
486                {
487                    let id = vec![segments[0].ident];
488                    let kind = TyKind::Path(None, ecx.path(span, id));
489                    let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
490                    thin_vec![ty]
491                } else {
492                    panic!("Expected tuple or simple path return type");
493                }
494            }
495            _ => {
496                // We messed up construction of d_sig
497                panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty);
498            }
499        };
500
501        if x.mode.is_fwd() && x.ret_activity == DiffActivity::Dual {
502            assert!(d_ret_ty.len() == 2);
503            // both should be identical, by construction
504            let arg = d_ret_ty[0].kind.is_simple_path().unwrap();
505            let arg2 = d_ret_ty[1].kind.is_simple_path().unwrap();
506            assert!(arg == arg2);
507            let sl: Vec<Symbol> = vec![arg, kw::Default];
508            let tmp = ecx.def_site_path(&sl);
509            let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
510            let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
511            exprs.push(default_call_expr);
512        } else if x.mode.is_rev() {
513            if primal_ret {
514                // We have extra handling above for the primal ret
515                d_ret_ty = d_ret_ty[1..].to_vec().into();
516            }
517
518            for arg in d_ret_ty.iter() {
519                let arg = arg.kind.is_simple_path().unwrap();
520                let sl: Vec<Symbol> = vec![arg, kw::Default];
521                let tmp = ecx.def_site_path(&sl);
522                let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
523                let default_call_expr =
524                    ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
525                exprs.push(default_call_expr);
526            }
527        }
528
529        let ret: P<ast::Expr>;
530        match &exprs[..] {
531            [] => {
532                assert!(!has_ret(&d_sig.decl.output));
533                // We don't have to match the return type.
534                return body;
535            }
536            [arg] => {
537                ret = ecx.expr_call(
538                    new_decl_span,
539                    blackbox_call_expr.clone(),
540                    thin_vec![arg.clone()],
541                );
542            }
543            args => {
544                let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, args.into());
545                ret =
546                    ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![ret_tuple]);
547            }
548        }
549        assert!(has_ret(&d_sig.decl.output));
550        body.stmts.push(ecx.stmt_expr(ret));
551
552        body
553    }
554
555    fn gen_primal_call(
556        ecx: &ExtCtxt<'_>,
557        span: Span,
558        primal: Ident,
559        idents: Vec<Ident>,
560    ) -> P<ast::Expr> {
561        let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
562        if has_self {
563            let args: ThinVec<_> =
564                idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
565            let self_expr = ecx.expr_self(span);
566            ecx.expr_method_call(span, self_expr, primal, args.clone())
567        } else {
568            let args: ThinVec<_> =
569                idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
570            let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
571            ecx.expr_call(span, primal_call_expr, args)
572        }
573    }
574
575    // Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
576    // be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
577    // Active arguments must be scalars. Their shadow argument is added to the return type (and will be
578    // zero-initialized by Enzyme).
579    // Each argument of the primal function (and the return type if existing) must be annotated with an
580    // activity.
581    //
582    // Error handling: If the user provides an invalid configuration (incorrect numbers, types, or
583    // both), we emit an error and return the original signature. This allows us to continue parsing.
584    fn gen_enzyme_decl(
585        ecx: &ExtCtxt<'_>,
586        sig: &ast::FnSig,
587        x: &AutoDiffAttrs,
588        span: Span,
589    ) -> (ast::FnSig, Vec<String>, Vec<Ident>, bool) {
590        let dcx = ecx.sess.dcx();
591        let has_ret = has_ret(&sig.decl.output);
592        let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
593        let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 };
594        if sig_args != num_activities {
595            dcx.emit_err(errors::AutoDiffInvalidNumberActivities {
596                span,
597                expected: sig_args,
598                found: num_activities,
599            });
600            // This is not the right signature, but we can continue parsing.
601            return (sig.clone(), vec![], vec![], true);
602        }
603        assert!(sig.decl.inputs.len() == x.input_activity.len());
604        assert!(has_ret == x.has_ret_activity());
605        let mut d_decl = sig.decl.clone();
606        let mut d_inputs = Vec::new();
607        let mut new_inputs = Vec::new();
608        let mut idents = Vec::new();
609        let mut act_ret = ThinVec::new();
610
611        // We have two loops, a first one just to check the activities and types and possibly report
612        // multiple errors in one compilation session.
613        let mut errors = false;
614        for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
615            if !valid_input_activity(x.mode, *activity) {
616                dcx.emit_err(errors::AutoDiffInvalidApplicationModeAct {
617                    span,
618                    mode: x.mode.to_string(),
619                    act: activity.to_string(),
620                });
621                errors = true;
622            }
623            if !valid_ty_for_activity(&arg.ty, *activity) {
624                dcx.emit_err(errors::AutoDiffInvalidTypeForActivity {
625                    span: arg.ty.span,
626                    act: activity.to_string(),
627                });
628                errors = true;
629            }
630        }
631        if errors {
632            // This is not the right signature, but we can continue parsing.
633            return (sig.clone(), new_inputs, idents, true);
634        }
635        let unsafe_activities = x
636            .input_activity
637            .iter()
638            .any(|&act| matches!(act, DiffActivity::DuplicatedOnly | DiffActivity::DualOnly));
639        for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
640            d_inputs.push(arg.clone());
641            match activity {
642                DiffActivity::Active => {
643                    act_ret.push(arg.ty.clone());
644                }
645                DiffActivity::ActiveOnly => {
646                    // We will add the active scalar to the return type.
647                    // This is handled later.
648                }
649                DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
650                    let mut shadow_arg = arg.clone();
651                    // We += into the shadow in reverse mode.
652                    shadow_arg.ty = P(assure_mut_ref(&arg.ty));
653                    let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
654                        ident.name
655                    } else {
656                        debug!("{:#?}", &shadow_arg.pat);
657                        panic!("not an ident?");
658                    };
659                    let name: String = format!("d{}", old_name);
660                    new_inputs.push(name.clone());
661                    let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
662                    shadow_arg.pat = P(ast::Pat {
663                        id: ast::DUMMY_NODE_ID,
664                        kind: PatKind::Ident(BindingMode::NONE, ident, None),
665                        span: shadow_arg.pat.span,
666                        tokens: shadow_arg.pat.tokens.clone(),
667                    });
668                    d_inputs.push(shadow_arg);
669                }
670                DiffActivity::Dual | DiffActivity::DualOnly => {
671                    let mut shadow_arg = arg.clone();
672                    let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
673                        ident.name
674                    } else {
675                        debug!("{:#?}", &shadow_arg.pat);
676                        panic!("not an ident?");
677                    };
678                    let name: String = format!("b{}", old_name);
679                    new_inputs.push(name.clone());
680                    let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
681                    shadow_arg.pat = P(ast::Pat {
682                        id: ast::DUMMY_NODE_ID,
683                        kind: PatKind::Ident(BindingMode::NONE, ident, None),
684                        span: shadow_arg.pat.span,
685                        tokens: shadow_arg.pat.tokens.clone(),
686                    });
687                    d_inputs.push(shadow_arg);
688                }
689                DiffActivity::Const => {
690                    // Nothing to do here.
691                }
692                DiffActivity::None | DiffActivity::FakeActivitySize => {
693                    panic!("Should not happen");
694                }
695            }
696            if let PatKind::Ident(_, ident, _) = arg.pat.kind {
697                idents.push(ident.clone());
698            } else {
699                panic!("not an ident?");
700            }
701        }
702
703        let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;
704        if active_only_ret {
705            assert!(x.mode.is_rev());
706        }
707
708        // If we return a scalar in the primal and the scalar is active,
709        // then add it as last arg to the inputs.
710        if x.mode.is_rev() {
711            match x.ret_activity {
712                DiffActivity::Active | DiffActivity::ActiveOnly => {
713                    let ty = match d_decl.output {
714                        FnRetTy::Ty(ref ty) => ty.clone(),
715                        FnRetTy::Default(span) => {
716                            panic!("Did not expect Default ret ty: {:?}", span);
717                        }
718                    };
719                    let name = "dret".to_string();
720                    let ident = Ident::from_str_and_span(&name, ty.span);
721                    let shadow_arg = ast::Param {
722                        attrs: ThinVec::new(),
723                        ty: ty.clone(),
724                        pat: P(ast::Pat {
725                            id: ast::DUMMY_NODE_ID,
726                            kind: PatKind::Ident(BindingMode::NONE, ident, None),
727                            span: ty.span,
728                            tokens: None,
729                        }),
730                        id: ast::DUMMY_NODE_ID,
731                        span: ty.span,
732                        is_placeholder: false,
733                    };
734                    d_inputs.push(shadow_arg);
735                    new_inputs.push(name);
736                }
737                _ => {}
738            }
739        }
740        d_decl.inputs = d_inputs.into();
741
742        if x.mode.is_fwd() {
743            if let DiffActivity::Dual = x.ret_activity {
744                let ty = match d_decl.output {
745                    FnRetTy::Ty(ref ty) => ty.clone(),
746                    FnRetTy::Default(span) => {
747                        panic!("Did not expect Default ret ty: {:?}", span);
748                    }
749                };
750                // Dual can only be used for f32/f64 ret.
751                // In that case we return now a tuple with two floats.
752                let kind = TyKind::Tup(thin_vec![ty.clone(), ty.clone()]);
753                let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
754                d_decl.output = FnRetTy::Ty(ty);
755            }
756            if let DiffActivity::DualOnly = x.ret_activity {
757                // No need to change the return type,
758                // we will just return the shadow in place
759                // of the primal return.
760            }
761        }
762
763        // If we use ActiveOnly, drop the original return value.
764        d_decl.output =
765            if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() };
766
767        trace!("act_ret: {:?}", act_ret);
768
769        // If we have an active input scalar, add it's gradient to the
770        // return type. This might require changing the return type to a
771        // tuple.
772        if act_ret.len() > 0 {
773            let ret_ty = match d_decl.output {
774                FnRetTy::Ty(ref ty) => {
775                    if !active_only_ret {
776                        act_ret.insert(0, ty.clone());
777                    }
778                    let kind = TyKind::Tup(act_ret);
779                    P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None })
780                }
781                FnRetTy::Default(span) => {
782                    if act_ret.len() == 1 {
783                        act_ret[0].clone()
784                    } else {
785                        let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect());
786                        P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None })
787                    }
788                }
789            };
790            d_decl.output = FnRetTy::Ty(ret_ty);
791        }
792
793        let mut d_header = sig.header.clone();
794        if unsafe_activities {
795            d_header.safety = rustc_ast::Safety::Unsafe(span);
796        }
797        let d_sig = FnSig { header: d_header, decl: d_decl, span };
798        trace!("Generated signature: {:?}", d_sig);
799        (d_sig, new_inputs, idents, false)
800    }
801}
802
803#[cfg(not(llvm_enzyme))]
804mod ad_fallback {
805    use rustc_ast::ast;
806    use rustc_expand::base::{Annotatable, ExtCtxt};
807    use rustc_span::Span;
808
809    use crate::errors;
810    pub(crate) fn expand(
811        ecx: &mut ExtCtxt<'_>,
812        _expand_span: Span,
813        meta_item: &ast::MetaItem,
814        item: Annotatable,
815    ) -> Vec<Annotatable> {
816        ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
817        return vec![item];
818    }
819}
820
821#[cfg(not(llvm_enzyme))]
822pub(crate) use ad_fallback::expand;
823#[cfg(llvm_enzyme)]
824pub(crate) use llvm_enzyme::expand;