rustc_builtin_macros/
autodiff.rs

1//! This module contains the implementation of the `#[autodiff]` attribute.
2//! Currently our linter isn't smart enough to see that each import is used in one of the two
3//! configs (autodiff enabled or disabled), so we have to add cfg's to each import.
4//! FIXME(ZuseZ4): Remove this once we have a smarter linter.
5
6mod llvm_enzyme {
7    use std::str::FromStr;
8    use std::string::String;
9
10    use rustc_ast::expand::autodiff_attrs::{
11        AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
12        valid_ty_for_activity,
13    };
14    use rustc_ast::token::{Lit, LitKind, Token, TokenKind};
15    use rustc_ast::tokenstream::*;
16    use rustc_ast::visit::AssocCtxt::*;
17    use rustc_ast::{
18        self as ast, AngleBracketedArg, AngleBracketedArgs, AnonConst, AssocItemKind, BindingMode,
19        FnRetTy, FnSig, GenericArg, GenericArgs, GenericParamKind, Generics, ItemKind,
20        MetaItemInner, MgcaDisambiguation, PatKind, Path, PathSegment, TyKind, Visibility,
21    };
22    use rustc_expand::base::{Annotatable, ExtCtxt};
23    use rustc_span::{Ident, Span, Symbol, sym};
24    use thin_vec::{ThinVec, thin_vec};
25    use tracing::{debug, trace};
26
27    use crate::errors;
28
29    pub(crate) fn outer_normal_attr(
30        kind: &Box<rustc_ast::NormalAttr>,
31        id: rustc_ast::AttrId,
32        span: Span,
33    ) -> rustc_ast::Attribute {
34        let style = rustc_ast::AttrStyle::Outer;
35        let kind = rustc_ast::AttrKind::Normal(kind.clone());
36        rustc_ast::Attribute { kind, id, style, span }
37    }
38
39    // If we have a default `()` return type or explicitley `()` return type,
40    // then we often can skip doing some work.
41    fn has_ret(ty: &FnRetTy) -> bool {
42        match ty {
43            FnRetTy::Ty(ty) => !ty.kind.is_unit(),
44            FnRetTy::Default(_) => false,
45        }
46    }
47    fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {
48        if let Some(l) = x.lit() {
49            match l.kind {
50                ast::LitKind::Int(val, _) => {
51                    // get an Ident from a lit
52                    return rustc_span::Ident::from_str(val.get().to_string().as_str());
53                }
54                _ => {}
55            }
56        }
57
58        let segments = &x.meta_item().unwrap().path.segments;
59        assert!(segments.len() == 1);
60        segments[0].ident
61    }
62
63    fn name(x: &MetaItemInner) -> String {
64        first_ident(x).name.to_string()
65    }
66
67    fn width(x: &MetaItemInner) -> Option<u128> {
68        let lit = x.lit()?;
69        match lit.kind {
70            ast::LitKind::Int(x, _) => Some(x.get()),
71            _ => return None,
72        }
73    }
74
75    // Get information about the function the macro is applied to
76    fn extract_item_info(iitem: &Box<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {
77        match &iitem.kind {
78            ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
79                Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
80            }
81            _ => None,
82        }
83    }
84
85    pub(crate) fn from_ast(
86        ecx: &mut ExtCtxt<'_>,
87        meta_item: &ThinVec<MetaItemInner>,
88        has_ret: bool,
89        mode: DiffMode,
90    ) -> AutoDiffAttrs {
91        let dcx = ecx.sess.dcx();
92
93        // Now we check, whether the user wants autodiff in batch/vector mode, or scalar mode.
94        // If he doesn't specify an integer (=width), we default to scalar mode, thus width=1.
95        let mut first_activity = 1;
96
97        let width = if let [_, x, ..] = &meta_item[..]
98            && let Some(x) = width(x)
99        {
100            first_activity = 2;
101            match x.try_into() {
102                Ok(x) => x,
103                Err(_) => {
104                    dcx.emit_err(errors::AutoDiffInvalidWidth {
105                        span: meta_item[1].span(),
106                        width: x,
107                    });
108                    return AutoDiffAttrs::error();
109                }
110            }
111        } else {
112            1
113        };
114
115        let mut activities: Vec<DiffActivity> = vec![];
116        let mut errors = false;
117        for x in &meta_item[first_activity..] {
118            let activity_str = name(&x);
119            let res = DiffActivity::from_str(&activity_str);
120            match res {
121                Ok(x) => activities.push(x),
122                Err(_) => {
123                    dcx.emit_err(errors::AutoDiffUnknownActivity {
124                        span: x.span(),
125                        act: activity_str,
126                    });
127                    errors = true;
128                }
129            };
130        }
131        if errors {
132            return AutoDiffAttrs::error();
133        }
134
135        // If a return type exist, we need to split the last activity,
136        // otherwise we return None as placeholder.
137        let (ret_activity, input_activity) = if has_ret {
138            let Some((last, rest)) = activities.split_last() else {
139                unreachable!(
140                    "should not be reachable because we counted the number of activities previously"
141                );
142            };
143            (last, rest)
144        } else {
145            (&DiffActivity::None, activities.as_slice())
146        };
147
148        AutoDiffAttrs {
149            mode,
150            width,
151            ret_activity: *ret_activity,
152            input_activity: input_activity.to_vec(),
153        }
154    }
155
156    fn meta_item_inner_to_ts(t: &MetaItemInner, ts: &mut Vec<TokenTree>) {
157        let comma: Token = Token::new(TokenKind::Comma, Span::default());
158        let val = first_ident(t);
159        let t = Token::from_ast_ident(val);
160        ts.push(TokenTree::Token(t, Spacing::Joint));
161        ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
162    }
163
164    pub(crate) fn expand_forward(
165        ecx: &mut ExtCtxt<'_>,
166        expand_span: Span,
167        meta_item: &ast::MetaItem,
168        item: Annotatable,
169    ) -> Vec<Annotatable> {
170        expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward)
171    }
172
173    pub(crate) fn expand_reverse(
174        ecx: &mut ExtCtxt<'_>,
175        expand_span: Span,
176        meta_item: &ast::MetaItem,
177        item: Annotatable,
178    ) -> Vec<Annotatable> {
179        expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse)
180    }
181
182    /// We expand the autodiff macro to generate a new placeholder function which passes
183    /// type-checking and can be called by users. The exact signature of the generated function
184    /// depends on the configuration provided by the user, but here is an example:
185    ///
186    /// ```
187    /// #[autodiff(cos_box, Reverse, Duplicated, Active)]
188    /// fn sin(x: &Box<f32>) -> f32 {
189    ///     f32::sin(**x)
190    /// }
191    /// ```
192    /// which becomes expanded to:
193    /// ```
194    /// #[rustc_autodiff]
195    /// fn sin(x: &Box<f32>) -> f32 {
196    ///     f32::sin(**x)
197    /// }
198    /// #[rustc_autodiff(Reverse, Duplicated, Active)]
199    /// fn cos_box(x: &Box<f32>, dx: &mut Box<f32>, dret: f32) -> f32 {
200    ///     std::intrinsics::autodiff(sin::<>, cos_box::<>, (x, dx, dret))
201    /// }
202    /// ```
203    /// FIXME(ZuseZ4): Once autodiff is enabled by default, make this a doc comment which is checked
204    /// in CI.
205    pub(crate) fn expand_with_mode(
206        ecx: &mut ExtCtxt<'_>,
207        expand_span: Span,
208        meta_item: &ast::MetaItem,
209        mut item: Annotatable,
210        mode: DiffMode,
211    ) -> Vec<Annotatable> {
212        let dcx = ecx.sess.dcx();
213
214        // first get information about the annotable item: visibility, signature, name and generic
215        // parameters.
216        // these will be used to generate the differentiated version of the function
217        let Some((vis, sig, primal, generics, impl_of_trait)) = (match &item {
218            Annotatable::Item(iitem) => {
219                extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false))
220            }
221            Annotatable::Stmt(stmt) => match &stmt.kind {
222                ast::StmtKind::Item(iitem) => {
223                    extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false))
224                }
225                _ => None,
226            },
227            Annotatable::AssocItem(assoc_item, Impl { of_trait }) => match &assoc_item.kind {
228                ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => Some((
229                    assoc_item.vis.clone(),
230                    sig.clone(),
231                    ident.clone(),
232                    generics.clone(),
233                    *of_trait,
234                )),
235                _ => None,
236            },
237            _ => None,
238        }) else {
239            dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
240            return vec![item];
241        };
242
243        let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
244            ast::MetaItemKind::List(ref vec) => vec.clone(),
245            _ => {
246                dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
247                return vec![item];
248            }
249        };
250
251        let has_ret = has_ret(&sig.decl.output);
252
253        // create TokenStream from vec elemtents:
254        // meta_item doesn't have a .tokens field
255        let mut ts: Vec<TokenTree> = vec![];
256        if meta_item_vec.len() < 1 {
257            // At the bare minimum, we need a fnc name.
258            dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
259            return vec![item];
260        }
261
262        let mode_symbol = match mode {
263            DiffMode::Forward => sym::Forward,
264            DiffMode::Reverse => sym::Reverse,
265            _ => unreachable!("Unsupported mode: {:?}", mode),
266        };
267
268        // Insert mode token
269        let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default());
270        ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint));
271        ts.insert(
272            1,
273            TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone),
274        );
275
276        // Now, if the user gave a width (vector aka batch-mode ad), then we copy it.
277        // If it is not given, we default to 1 (scalar mode).
278        let start_position;
279        let kind: LitKind = LitKind::Integer;
280        let symbol;
281        if meta_item_vec.len() >= 2
282            && let Some(width) = width(&meta_item_vec[1])
283        {
284            start_position = 2;
285            symbol = Symbol::intern(&width.to_string());
286        } else {
287            start_position = 1;
288            symbol = sym::integer(1);
289        }
290
291        let l: Lit = Lit { kind, symbol, suffix: None };
292        let t = Token::new(TokenKind::Literal(l), Span::default());
293        let comma = Token::new(TokenKind::Comma, Span::default());
294        ts.push(TokenTree::Token(t, Spacing::Joint));
295        ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
296
297        for t in meta_item_vec.clone()[start_position..].iter() {
298            meta_item_inner_to_ts(t, &mut ts);
299        }
300
301        if !has_ret {
302            // We don't want users to provide a return activity if the function doesn't return anything.
303            // For simplicity, we just add a dummy token to the end of the list.
304            let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
305            ts.push(TokenTree::Token(t, Spacing::Joint));
306            ts.push(TokenTree::Token(comma, Spacing::Alone));
307        }
308        // We remove the last, trailing comma.
309        ts.pop();
310        let ts: TokenStream = TokenStream::from_iter(ts);
311
312        let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret, mode);
313        if !x.is_active() {
314            // We encountered an error, so we return the original item.
315            // This allows us to potentially parse other attributes.
316            return vec![item];
317        }
318        let span = ecx.with_def_site_ctxt(expand_span);
319
320        let d_sig = gen_enzyme_decl(ecx, &sig, &x, span);
321
322        let d_body = ecx.block(
323            span,
324            thin_vec![call_autodiff(
325                ecx,
326                primal,
327                first_ident(&meta_item_vec[0]),
328                span,
329                &d_sig,
330                &generics,
331                impl_of_trait,
332            )],
333        );
334
335        // The first element of it is the name of the function to be generated
336        let d_fn = Box::new(ast::Fn {
337            defaultness: ast::Defaultness::Final,
338            sig: d_sig,
339            ident: first_ident(&meta_item_vec[0]),
340            generics,
341            contract: None,
342            body: Some(d_body),
343            define_opaque: None,
344            eii_impls: ThinVec::new(),
345        });
346        let mut rustc_ad_attr =
347            Box::new(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
348
349        let ts2: Vec<TokenTree> = vec![TokenTree::Token(
350            Token::new(TokenKind::Ident(sym::never, false.into()), span),
351            Spacing::Joint,
352        )];
353        let never_arg = ast::DelimArgs {
354            dspan: DelimSpan::from_single(span),
355            delim: ast::token::Delimiter::Parenthesis,
356            tokens: TokenStream::from_iter(ts2),
357        };
358        let inline_item = ast::AttrItem {
359            unsafety: ast::Safety::Default,
360            path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),
361            args: rustc_ast::ast::AttrItemKind::Unparsed(ast::AttrArgs::Delimited(never_arg)),
362            tokens: None,
363        };
364        let inline_never_attr = Box::new(ast::NormalAttr { item: inline_item, tokens: None });
365        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
366        let attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
367        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
368        let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);
369
370        // We're avoid duplicating the attribute `#[rustc_autodiff]`.
371        fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
372            match (attr, item) {
373                (ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {
374                    let a = &a.item.path;
375                    let b = &b.item.path;
376                    a.segments.iter().eq_by(&b.segments, |a, b| a.ident == b.ident)
377                }
378                _ => false,
379            }
380        }
381
382        let mut has_inline_never = false;
383
384        // Don't add it multiple times:
385        let orig_annotatable: Annotatable = match item {
386            Annotatable::Item(ref mut iitem) => {
387                if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
388                    iitem.attrs.push(attr);
389                }
390                if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
391                    has_inline_never = true;
392                }
393                Annotatable::Item(iitem.clone())
394            }
395            Annotatable::AssocItem(ref mut assoc_item, i @ Impl { .. }) => {
396                if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
397                    assoc_item.attrs.push(attr);
398                }
399                if assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
400                    has_inline_never = true;
401                }
402                Annotatable::AssocItem(assoc_item.clone(), i)
403            }
404            Annotatable::Stmt(ref mut stmt) => {
405                match stmt.kind {
406                    ast::StmtKind::Item(ref mut iitem) => {
407                        if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
408                            iitem.attrs.push(attr);
409                        }
410                        if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
411                            has_inline_never = true;
412                        }
413                    }
414                    _ => unreachable!("stmt kind checked previously"),
415                };
416
417                Annotatable::Stmt(stmt.clone())
418            }
419            _ => {
420                unreachable!("annotatable kind checked previously")
421            }
422        };
423        // Now update for d_fn
424        rustc_ad_attr.item.args = rustc_ast::ast::AttrItemKind::Unparsed(
425            rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {
426                dspan: DelimSpan::dummy(),
427                delim: rustc_ast::token::Delimiter::Parenthesis,
428                tokens: ts,
429            }),
430        );
431
432        let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
433        let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
434
435        // If the source function has the `#[inline(never)]` attribute, we'll also add it to the diff function
436        let mut d_attrs = thin_vec![d_attr];
437
438        if has_inline_never {
439            d_attrs.push(inline_never);
440        }
441
442        let d_annotatable = match &item {
443            Annotatable::AssocItem(_, _) => {
444                let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(d_fn);
445                let d_fn = Box::new(ast::AssocItem {
446                    attrs: d_attrs,
447                    id: ast::DUMMY_NODE_ID,
448                    span,
449                    vis,
450                    kind: assoc_item,
451                    tokens: None,
452                });
453                Annotatable::AssocItem(d_fn, Impl { of_trait: false })
454            }
455            Annotatable::Item(_) => {
456                let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));
457                d_fn.vis = vis;
458
459                Annotatable::Item(d_fn)
460            }
461            Annotatable::Stmt(_) => {
462                let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));
463                d_fn.vis = vis;
464
465                Annotatable::Stmt(Box::new(ast::Stmt {
466                    id: ast::DUMMY_NODE_ID,
467                    kind: ast::StmtKind::Item(d_fn),
468                    span,
469                }))
470            }
471            _ => {
472                unreachable!("item kind checked previously")
473            }
474        };
475
476        return vec![orig_annotatable, d_annotatable];
477    }
478
479    // shadow arguments (the extra ones which were not in the original (primal) function), in reverse mode must be
480    // mutable references or ptrs, because Enzyme will write into them.
481    fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
482        let mut ty = ty.clone();
483        match ty.kind {
484            TyKind::Ptr(ref mut mut_ty) => {
485                mut_ty.mutbl = ast::Mutability::Mut;
486            }
487            TyKind::Ref(_, ref mut mut_ty) => {
488                mut_ty.mutbl = ast::Mutability::Mut;
489            }
490            _ => {
491                panic!("unsupported type: {:?}", ty);
492            }
493        }
494        ty
495    }
496
497    // Generate `autodiff` intrinsic call
498    // ```
499    // std::intrinsics::autodiff(source, diff, (args))
500    // ```
501    fn call_autodiff(
502        ecx: &ExtCtxt<'_>,
503        primal: Ident,
504        diff: Ident,
505        span: Span,
506        d_sig: &FnSig,
507        generics: &Generics,
508        is_impl: bool,
509    ) -> rustc_ast::Stmt {
510        let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span, is_impl);
511        let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span, is_impl);
512
513        let tuple_expr = ecx.expr_tuple(
514            span,
515            d_sig
516                .decl
517                .inputs
518                .iter()
519                .map(|arg| match arg.pat.kind {
520                    PatKind::Ident(_, ident, _) => ecx.expr_path(ecx.path_ident(span, ident)),
521                    _ => todo!(),
522                })
523                .collect::<ThinVec<_>>()
524                .into(),
525        );
526
527        let enzyme_path_idents = ecx.std_path(&[sym::intrinsics, sym::autodiff]);
528        let enzyme_path = ecx.path(span, enzyme_path_idents);
529        let call_expr = ecx.expr_call(
530            span,
531            ecx.expr_path(enzyme_path),
532            vec![primal_path_expr, diff_path_expr, tuple_expr].into(),
533        );
534
535        ecx.stmt_expr(call_expr)
536    }
537
538    // Generate turbofish expression from fn name and generics
539    // Given `foo` and `<A, B, C>` params, gen `foo::<A, B, C>`
540    // We use this expression when passing primal and diff function to the autodiff intrinsic
541    fn gen_turbofish_expr(
542        ecx: &ExtCtxt<'_>,
543        ident: Ident,
544        generics: &Generics,
545        span: Span,
546        is_impl: bool,
547    ) -> Box<ast::Expr> {
548        let generic_args = generics
549            .params
550            .iter()
551            .filter_map(|p| match &p.kind {
552                GenericParamKind::Type { .. } => {
553                    let path = ast::Path::from_ident(p.ident);
554                    let ty = ecx.ty_path(path);
555                    Some(AngleBracketedArg::Arg(GenericArg::Type(ty)))
556                }
557                GenericParamKind::Const { .. } => {
558                    let expr = ecx.expr_path(ast::Path::from_ident(p.ident));
559                    let anon_const = AnonConst {
560                        id: ast::DUMMY_NODE_ID,
561                        value: expr,
562                        mgca_disambiguation: MgcaDisambiguation::Direct,
563                    };
564                    Some(AngleBracketedArg::Arg(GenericArg::Const(anon_const)))
565                }
566                GenericParamKind::Lifetime { .. } => None,
567            })
568            .collect::<ThinVec<_>>();
569
570        let args: AngleBracketedArgs = AngleBracketedArgs { span, args: generic_args };
571
572        let segment = PathSegment {
573            ident,
574            id: ast::DUMMY_NODE_ID,
575            args: Some(Box::new(GenericArgs::AngleBracketed(args))),
576        };
577
578        let segments = if is_impl {
579            thin_vec![
580                PathSegment { ident: Ident::from_str("Self"), id: ast::DUMMY_NODE_ID, args: None },
581                segment,
582            ]
583        } else {
584            thin_vec![segment]
585        };
586
587        let path = Path { span, segments, tokens: None };
588
589        ecx.expr_path(path)
590    }
591
592    // Generate the new function declaration. Const arguments are kept as is. Duplicated arguments must
593    // be pointers or references. Those receive a shadow argument, which is a mutable reference/pointer.
594    // Active arguments must be scalars. Their shadow argument is added to the return type (and will be
595    // zero-initialized by Enzyme).
596    // Each argument of the primal function (and the return type if existing) must be annotated with an
597    // activity.
598    //
599    // Error handling: If the user provides an invalid configuration (incorrect numbers, types, or
600    // both), we emit an error and return the original signature. This allows us to continue parsing.
601    // FIXME(Sa4dUs): make individual activities' span available so errors
602    // can point to only the activity instead of the entire attribute
603    fn gen_enzyme_decl(
604        ecx: &ExtCtxt<'_>,
605        sig: &ast::FnSig,
606        x: &AutoDiffAttrs,
607        span: Span,
608    ) -> ast::FnSig {
609        let dcx = ecx.sess.dcx();
610        let has_ret = has_ret(&sig.decl.output);
611        let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
612        let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 };
613        if sig_args != num_activities {
614            dcx.emit_err(errors::AutoDiffInvalidNumberActivities {
615                span,
616                expected: sig_args,
617                found: num_activities,
618            });
619            // This is not the right signature, but we can continue parsing.
620            return sig.clone();
621        }
622        assert!(sig.decl.inputs.len() == x.input_activity.len());
623        assert!(has_ret == x.has_ret_activity());
624        let mut d_decl = sig.decl.clone();
625        let mut d_inputs = Vec::new();
626        let mut new_inputs = Vec::new();
627        let mut idents = Vec::new();
628        let mut act_ret = ThinVec::new();
629
630        // We have two loops, a first one just to check the activities and types and possibly report
631        // multiple errors in one compilation session.
632        let mut errors = false;
633        for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
634            if !valid_input_activity(x.mode, *activity) {
635                dcx.emit_err(errors::AutoDiffInvalidApplicationModeAct {
636                    span,
637                    mode: x.mode.to_string(),
638                    act: activity.to_string(),
639                });
640                errors = true;
641            }
642            if !valid_ty_for_activity(&arg.ty, *activity) {
643                dcx.emit_err(errors::AutoDiffInvalidTypeForActivity {
644                    span: arg.ty.span,
645                    act: activity.to_string(),
646                });
647                errors = true;
648            }
649        }
650
651        if has_ret && !valid_ret_activity(x.mode, x.ret_activity) {
652            dcx.emit_err(errors::AutoDiffInvalidRetAct {
653                span,
654                mode: x.mode.to_string(),
655                act: x.ret_activity.to_string(),
656            });
657            // We don't set `errors = true` to avoid annoying type errors relative
658            // to the expanded macro type signature
659        }
660
661        if errors {
662            // This is not the right signature, but we can continue parsing.
663            return sig.clone();
664        }
665
666        let unsafe_activities = x
667            .input_activity
668            .iter()
669            .any(|&act| matches!(act, DiffActivity::DuplicatedOnly | DiffActivity::DualOnly));
670        for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
671            d_inputs.push(arg.clone());
672            match activity {
673                DiffActivity::Active => {
674                    act_ret.push(arg.ty.clone());
675                    // if width =/= 1, then push [arg.ty; width] to act_ret
676                }
677                DiffActivity::ActiveOnly => {
678                    // We will add the active scalar to the return type.
679                    // This is handled later.
680                }
681                DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
682                    for i in 0..x.width {
683                        let mut shadow_arg = arg.clone();
684                        // We += into the shadow in reverse mode.
685                        shadow_arg.ty = Box::new(assure_mut_ref(&arg.ty));
686                        let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
687                            ident.name
688                        } else {
689                            debug!("{:#?}", &shadow_arg.pat);
690                            panic!("not an ident?");
691                        };
692                        let name: String = format!("d{}_{}", old_name, i);
693                        new_inputs.push(name.clone());
694                        let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
695                        shadow_arg.pat = Box::new(ast::Pat {
696                            id: ast::DUMMY_NODE_ID,
697                            kind: PatKind::Ident(BindingMode::NONE, ident, None),
698                            span: shadow_arg.pat.span,
699                            tokens: shadow_arg.pat.tokens.clone(),
700                        });
701                        d_inputs.push(shadow_arg.clone());
702                    }
703                }
704                DiffActivity::Dual
705                | DiffActivity::DualOnly
706                | DiffActivity::Dualv
707                | DiffActivity::DualvOnly => {
708                    // the *v variants get lowered to enzyme_dupv and enzyme_dupnoneedv, which cause
709                    // Enzyme to not expect N arguments, but one argument (which is instead larger).
710                    let iterations =
711                        if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) {
712                            1
713                        } else {
714                            x.width
715                        };
716                    for i in 0..iterations {
717                        let mut shadow_arg = arg.clone();
718                        let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
719                            ident.name
720                        } else {
721                            debug!("{:#?}", &shadow_arg.pat);
722                            panic!("not an ident?");
723                        };
724                        let name: String = format!("b{}_{}", old_name, i);
725                        new_inputs.push(name.clone());
726                        let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
727                        shadow_arg.pat = Box::new(ast::Pat {
728                            id: ast::DUMMY_NODE_ID,
729                            kind: PatKind::Ident(BindingMode::NONE, ident, None),
730                            span: shadow_arg.pat.span,
731                            tokens: shadow_arg.pat.tokens.clone(),
732                        });
733                        d_inputs.push(shadow_arg.clone());
734                    }
735                }
736                DiffActivity::Const => {
737                    // Nothing to do here.
738                }
739                DiffActivity::None | DiffActivity::FakeActivitySize(_) => {
740                    panic!("Should not happen");
741                }
742            }
743            if let PatKind::Ident(_, ident, _) = arg.pat.kind {
744                idents.push(ident.clone());
745            } else {
746                panic!("not an ident?");
747            }
748        }
749
750        let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;
751        if active_only_ret {
752            assert!(x.mode.is_rev());
753        }
754
755        // If we return a scalar in the primal and the scalar is active,
756        // then add it as last arg to the inputs.
757        if x.mode.is_rev() {
758            match x.ret_activity {
759                DiffActivity::Active | DiffActivity::ActiveOnly => {
760                    let ty = match d_decl.output {
761                        FnRetTy::Ty(ref ty) => ty.clone(),
762                        FnRetTy::Default(span) => {
763                            panic!("Did not expect Default ret ty: {:?}", span);
764                        }
765                    };
766                    let name = "dret".to_string();
767                    let ident = Ident::from_str_and_span(&name, ty.span);
768                    let shadow_arg = ast::Param {
769                        attrs: ThinVec::new(),
770                        ty: ty.clone(),
771                        pat: Box::new(ast::Pat {
772                            id: ast::DUMMY_NODE_ID,
773                            kind: PatKind::Ident(BindingMode::NONE, ident, None),
774                            span: ty.span,
775                            tokens: None,
776                        }),
777                        id: ast::DUMMY_NODE_ID,
778                        span: ty.span,
779                        is_placeholder: false,
780                    };
781                    d_inputs.push(shadow_arg);
782                    new_inputs.push(name);
783                }
784                _ => {}
785            }
786        }
787        d_decl.inputs = d_inputs.into();
788
789        if x.mode.is_fwd() {
790            let ty = match d_decl.output {
791                FnRetTy::Ty(ref ty) => ty.clone(),
792                FnRetTy::Default(span) => {
793                    // We want to return std::hint::black_box(()).
794                    let kind = TyKind::Tup(ThinVec::new());
795                    let ty = Box::new(rustc_ast::Ty {
796                        kind,
797                        id: ast::DUMMY_NODE_ID,
798                        span,
799                        tokens: None,
800                    });
801                    d_decl.output = FnRetTy::Ty(ty.clone());
802                    assert!(matches!(x.ret_activity, DiffActivity::None));
803                    // this won't be used below, so any type would be fine.
804                    ty
805                }
806            };
807
808            if matches!(x.ret_activity, DiffActivity::Dual | DiffActivity::Dualv) {
809                let kind = if x.width == 1 || matches!(x.ret_activity, DiffActivity::Dualv) {
810                    // Dual can only be used for f32/f64 ret.
811                    // In that case we now return a tuple with two floats.
812                    TyKind::Tup(thin_vec![ty.clone(), ty.clone()])
813                } else {
814                    // We have to return [T; width+1], +1 for the primal return.
815                    let anon_const = rustc_ast::AnonConst {
816                        id: ast::DUMMY_NODE_ID,
817                        value: ecx.expr_usize(span, 1 + x.width as usize),
818                        mgca_disambiguation: MgcaDisambiguation::Direct,
819                    };
820                    TyKind::Array(ty.clone(), anon_const)
821                };
822                let ty = Box::new(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
823                d_decl.output = FnRetTy::Ty(ty);
824            }
825            if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) {
826                // No need to change the return type,
827                // we will just return the shadow in place of the primal return.
828                // However, if we have a width > 1, then we don't return -> T, but -> [T; width]
829                if x.width > 1 {
830                    let anon_const = rustc_ast::AnonConst {
831                        id: ast::DUMMY_NODE_ID,
832                        value: ecx.expr_usize(span, x.width as usize),
833                        mgca_disambiguation: MgcaDisambiguation::Direct,
834                    };
835                    let kind = TyKind::Array(ty.clone(), anon_const);
836                    let ty =
837                        Box::new(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
838                    d_decl.output = FnRetTy::Ty(ty);
839                }
840            }
841        }
842
843        // If we use ActiveOnly, drop the original return value.
844        d_decl.output =
845            if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() };
846
847        trace!("act_ret: {:?}", act_ret);
848
849        // If we have an active input scalar, add its gradient to the
850        // return type. This might require changing the return type to a
851        // tuple.
852        if act_ret.len() > 0 {
853            let ret_ty = match d_decl.output {
854                FnRetTy::Ty(ref ty) => {
855                    if !active_only_ret {
856                        act_ret.insert(0, ty.clone());
857                    }
858                    let kind = TyKind::Tup(act_ret);
859                    Box::new(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None })
860                }
861                FnRetTy::Default(span) => {
862                    if act_ret.len() == 1 {
863                        act_ret[0].clone()
864                    } else {
865                        let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect());
866                        Box::new(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None })
867                    }
868                }
869            };
870            d_decl.output = FnRetTy::Ty(ret_ty);
871        }
872
873        let mut d_header = sig.header.clone();
874        if unsafe_activities {
875            d_header.safety = rustc_ast::Safety::Unsafe(span);
876        }
877        let d_sig = FnSig { header: d_header, decl: d_decl, span };
878        trace!("Generated signature: {:?}", d_sig);
879        d_sig
880    }
881}
882
883pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};