1#[cfg(llvm_enzyme)]
7mod llvm_enzyme {
8 use std::str::FromStr;
9 use std::string::String;
10
11 use rustc_ast::expand::autodiff_attrs::{
12 AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ty_for_activity,
13 };
14 use rustc_ast::ptr::P;
15 use rustc_ast::token::{Token, TokenKind};
16 use rustc_ast::tokenstream::*;
17 use rustc_ast::visit::AssocCtxt::*;
18 use rustc_ast::{
19 self as ast, AssocItemKind, BindingMode, FnRetTy, FnSig, Generics, ItemKind, MetaItemInner,
20 PatKind, TyKind,
21 };
22 use rustc_expand::base::{Annotatable, ExtCtxt};
23 use rustc_span::{Ident, Span, Symbol, kw, sym};
24 use thin_vec::{ThinVec, thin_vec};
25 use tracing::{debug, trace};
26
27 use crate::errors;
28
29 fn has_ret(ty: &FnRetTy) -> bool {
32 match ty {
33 FnRetTy::Ty(ty) => !ty.kind.is_unit(),
34 FnRetTy::Default(_) => false,
35 }
36 }
37 fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {
38 let segments = &x.meta_item().unwrap().path.segments;
39 assert!(segments.len() == 1);
40 segments[0].ident
41 }
42
43 fn name(x: &MetaItemInner) -> String {
44 first_ident(x).name.to_string()
45 }
46
47 pub(crate) fn from_ast(
48 ecx: &mut ExtCtxt<'_>,
49 meta_item: &ThinVec<MetaItemInner>,
50 has_ret: bool,
51 ) -> AutoDiffAttrs {
52 let dcx = ecx.sess.dcx();
53 let mode = name(&meta_item[1]);
54 let Ok(mode) = DiffMode::from_str(&mode) else {
55 dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
56 return AutoDiffAttrs::error();
57 };
58 let mut activities: Vec<DiffActivity> = vec![];
59 let mut errors = false;
60 for x in &meta_item[2..] {
61 let activity_str = name(&x);
62 let res = DiffActivity::from_str(&activity_str);
63 match res {
64 Ok(x) => activities.push(x),
65 Err(_) => {
66 dcx.emit_err(errors::AutoDiffUnknownActivity {
67 span: x.span(),
68 act: activity_str,
69 });
70 errors = true;
71 }
72 };
73 }
74 if errors {
75 return AutoDiffAttrs::error();
76 }
77
78 let (ret_activity, input_activity) = if has_ret {
81 let Some((last, rest)) = activities.split_last() else {
82 unreachable!(
83 "should not be reachable because we counted the number of activities previously"
84 );
85 };
86 (last, rest)
87 } else {
88 (&DiffActivity::None, activities.as_slice())
89 };
90
91 AutoDiffAttrs { mode, ret_activity: *ret_activity, input_activity: input_activity.to_vec() }
92 }
93
94 pub(crate) fn expand(
128 ecx: &mut ExtCtxt<'_>,
129 expand_span: Span,
130 meta_item: &ast::MetaItem,
131 mut item: Annotatable,
132 ) -> Vec<Annotatable> {
133 let dcx = ecx.sess.dcx();
134 let (sig, is_impl): (FnSig, bool) = match &item {
136 Annotatable::Item(ref iitem) => {
137 let sig = match &iitem.kind {
138 ItemKind::Fn(box ast::Fn { sig, .. }) => sig,
139 _ => {
140 dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
141 return vec![item];
142 }
143 };
144 (sig.clone(), false)
145 }
146 Annotatable::AssocItem(ref assoc_item, _) => {
147 let sig = match &assoc_item.kind {
148 ast::AssocItemKind::Fn(box ast::Fn { sig, .. }) => sig,
149 _ => {
150 dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
151 return vec![item];
152 }
153 };
154 (sig.clone(), true)
155 }
156 _ => {
157 dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
158 return vec![item];
159 }
160 };
161
162 let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
163 ast::MetaItemKind::List(ref vec) => vec.clone(),
164 _ => {
165 dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
166 return vec![item];
167 }
168 };
169
170 let has_ret = has_ret(&sig.decl.output);
171 let sig_span = ecx.with_call_site_ctxt(sig.span);
172
173 let (vis, primal) = match &item {
174 Annotatable::Item(ref iitem) => (iitem.vis.clone(), iitem.ident.clone()),
175 Annotatable::AssocItem(ref assoc_item, _) => {
176 (assoc_item.vis.clone(), assoc_item.ident.clone())
177 }
178 _ => {
179 dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
180 return vec![item];
181 }
182 };
183
184 let comma: Token = Token::new(TokenKind::Comma, Span::default());
187 let mut ts: Vec<TokenTree> = vec![];
188 if meta_item_vec.len() < 2 {
189 dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
192 return vec![item];
193 } else {
194 for t in meta_item_vec.clone()[1..].iter() {
195 let val = first_ident(t);
196 let t = Token::from_ast_ident(val);
197 ts.push(TokenTree::Token(t, Spacing::Joint));
198 ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
199 }
200 }
201 if !has_ret {
202 let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
205 ts.push(TokenTree::Token(t, Spacing::Joint));
206 }
207 let ts: TokenStream = TokenStream::from_iter(ts);
208
209 let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
210 if !x.is_active() {
211 return vec![item];
214 }
215 let span = ecx.with_def_site_ctxt(expand_span);
216
217 let n_active: u32 = x
218 .input_activity
219 .iter()
220 .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
221 .count() as u32;
222 let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
223 let new_decl_span = d_sig.span;
224 let d_body = gen_enzyme_body(
225 ecx,
226 &x,
227 n_active,
228 &sig,
229 &d_sig,
230 primal,
231 &new_args,
232 span,
233 sig_span,
234 new_decl_span,
235 idents,
236 errored,
237 );
238 let d_ident = first_ident(&meta_item_vec[0]);
239
240 let asdf = Box::new(ast::Fn {
242 defaultness: ast::Defaultness::Final,
243 sig: d_sig,
244 generics: Generics::default(),
245 body: Some(d_body),
246 });
247 let mut rustc_ad_attr =
248 P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
249
250 let ts2: Vec<TokenTree> = vec![TokenTree::Token(
251 Token::new(TokenKind::Ident(sym::never, false.into()), span),
252 Spacing::Joint,
253 )];
254 let never_arg = ast::DelimArgs {
255 dspan: ast::tokenstream::DelimSpan::from_single(span),
256 delim: ast::token::Delimiter::Parenthesis,
257 tokens: ast::tokenstream::TokenStream::from_iter(ts2),
258 };
259 let inline_item = ast::AttrItem {
260 unsafety: ast::Safety::Default,
261 path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),
262 args: ast::AttrArgs::Delimited(never_arg),
263 tokens: None,
264 };
265 let inline_never_attr = P(ast::NormalAttr { item: inline_item, tokens: None });
266 let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
267 let attr: ast::Attribute = ast::Attribute {
268 kind: ast::AttrKind::Normal(rustc_ad_attr.clone()),
269 id: new_id,
270 style: ast::AttrStyle::Outer,
271 span,
272 };
273 let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
274 let inline_never: ast::Attribute = ast::Attribute {
275 kind: ast::AttrKind::Normal(inline_never_attr),
276 id: new_id,
277 style: ast::AttrStyle::Outer,
278 span,
279 };
280
281 let orig_annotatable: Annotatable = match item {
283 Annotatable::Item(ref mut iitem) => {
284 if !iitem.attrs.iter().any(|a| a.id == attr.id) {
285 iitem.attrs.push(attr.clone());
286 }
287 if !iitem.attrs.iter().any(|a| a.id == inline_never.id) {
288 iitem.attrs.push(inline_never.clone());
289 }
290 Annotatable::Item(iitem.clone())
291 }
292 Annotatable::AssocItem(ref mut assoc_item, i @ Impl) => {
293 if !assoc_item.attrs.iter().any(|a| a.id == attr.id) {
294 assoc_item.attrs.push(attr.clone());
295 }
296 if !assoc_item.attrs.iter().any(|a| a.id == inline_never.id) {
297 assoc_item.attrs.push(inline_never.clone());
298 }
299 Annotatable::AssocItem(assoc_item.clone(), i)
300 }
301 _ => {
302 unreachable!("annotatable kind checked previously")
303 }
304 };
305 rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {
307 dspan: DelimSpan::dummy(),
308 delim: rustc_ast::token::Delimiter::Parenthesis,
309 tokens: ts,
310 });
311 let d_attr: ast::Attribute = ast::Attribute {
312 kind: ast::AttrKind::Normal(rustc_ad_attr.clone()),
313 id: new_id,
314 style: ast::AttrStyle::Outer,
315 span,
316 };
317
318 let d_annotatable = if is_impl {
319 let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
320 let d_fn = P(ast::AssocItem {
321 attrs: thin_vec![d_attr.clone(), inline_never],
322 id: ast::DUMMY_NODE_ID,
323 span,
324 vis,
325 ident: d_ident,
326 kind: assoc_item,
327 tokens: None,
328 });
329 Annotatable::AssocItem(d_fn, Impl)
330 } else {
331 let mut d_fn = ecx.item(
332 span,
333 d_ident,
334 thin_vec![d_attr.clone(), inline_never],
335 ItemKind::Fn(asdf),
336 );
337 d_fn.vis = vis;
338 Annotatable::Item(d_fn)
339 };
340
341 return vec![orig_annotatable, d_annotatable];
342 }
343
344 fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
347 let mut ty = ty.clone();
348 match ty.kind {
349 TyKind::Ptr(ref mut mut_ty) => {
350 mut_ty.mutbl = ast::Mutability::Mut;
351 }
352 TyKind::Ref(_, ref mut mut_ty) => {
353 mut_ty.mutbl = ast::Mutability::Mut;
354 }
355 _ => {
356 panic!("unsupported type: {:?}", ty);
357 }
358 }
359 ty
360 }
361
362 fn gen_enzyme_body(
373 ecx: &ExtCtxt<'_>,
374 x: &AutoDiffAttrs,
375 n_active: u32,
376 sig: &ast::FnSig,
377 d_sig: &ast::FnSig,
378 primal: Ident,
379 new_names: &[String],
380 span: Span,
381 sig_span: Span,
382 new_decl_span: Span,
383 idents: Vec<Ident>,
384 errored: bool,
385 ) -> P<ast::Block> {
386 let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
387 let noop = ast::InlineAsm {
388 asm_macro: ast::AsmMacro::Asm,
389 template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())],
390 template_strs: Box::new([]),
391 operands: vec![],
392 clobber_abis: vec![],
393 options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM,
394 line_spans: vec![],
395 };
396 let noop_expr = ecx.expr_asm(span, P(noop));
397 let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated);
398 let unsf_block = ast::Block {
399 stmts: thin_vec![ecx.stmt_semi(noop_expr)],
400 id: ast::DUMMY_NODE_ID,
401 tokens: None,
402 rules: unsf,
403 span,
404 could_be_bare_literal: false,
405 };
406 let unsf_expr = ecx.expr_block(P(unsf_block));
407 let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
408 let primal_call = gen_primal_call(ecx, span, primal, idents);
409 let black_box_primal_call = ecx.expr_call(
410 new_decl_span,
411 blackbox_call_expr.clone(),
412 thin_vec![primal_call.clone()],
413 );
414 let tup_args = new_names
415 .iter()
416 .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg))))
417 .collect();
418
419 let black_box_remaining_args = ecx.expr_call(
420 sig_span,
421 blackbox_call_expr.clone(),
422 thin_vec![ecx.expr_tuple(sig_span, tup_args)],
423 );
424
425 let mut body = ecx.block(span, ThinVec::new());
426 body.stmts.push(ecx.stmt_semi(unsf_expr));
427
428 if !errored {
430 body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone()));
431 }
432 body.stmts.push(ecx.stmt_semi(black_box_remaining_args));
433
434 if !has_ret(&d_sig.decl.output) {
435 return body;
437 }
438
439 let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
442
443 if primal_ret && n_active == 0 && x.mode.is_rev() {
444 body.stmts.push(ecx.stmt_expr(black_box_primal_call.clone()));
446 return body;
447 }
448
449 if !primal_ret && n_active == 1 {
450 let ty = match d_sig.decl.output {
452 FnRetTy::Ty(ref ty) => ty.clone(),
453 FnRetTy::Default(span) => {
454 panic!("Did not expect Default ret ty: {:?}", span);
455 }
456 };
457 let arg = ty.kind.is_simple_path().unwrap();
458 let sl: Vec<Symbol> = vec![arg, kw::Default];
459 let tmp = ecx.def_site_path(&sl);
460 let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
461 let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
462 body.stmts.push(ecx.stmt_expr(default_call_expr));
463 return body;
464 }
465
466 let mut exprs = ThinVec::<P<ast::Expr>>::new();
467 if primal_ret {
468 exprs.push(primal_call.clone());
471 }
472
473 let d_ret_ty = match d_sig.decl.output {
476 FnRetTy::Ty(ref ty) => ty.clone(),
477 FnRetTy::Default(span) => {
478 panic!("Did not expect Default ret ty: {:?}", span);
479 }
480 };
481 let mut d_ret_ty = match d_ret_ty.kind.clone() {
482 TyKind::Tup(ref tys) => tys.clone(),
483 TyKind::Path(_, rustc_ast::Path { segments, .. }) => {
484 if let [segment] = &segments[..]
485 && segment.args.is_none()
486 {
487 let id = vec![segments[0].ident];
488 let kind = TyKind::Path(None, ecx.path(span, id));
489 let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
490 thin_vec![ty]
491 } else {
492 panic!("Expected tuple or simple path return type");
493 }
494 }
495 _ => {
496 panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty);
498 }
499 };
500
501 if x.mode.is_fwd() && x.ret_activity == DiffActivity::Dual {
502 assert!(d_ret_ty.len() == 2);
503 let arg = d_ret_ty[0].kind.is_simple_path().unwrap();
505 let arg2 = d_ret_ty[1].kind.is_simple_path().unwrap();
506 assert!(arg == arg2);
507 let sl: Vec<Symbol> = vec![arg, kw::Default];
508 let tmp = ecx.def_site_path(&sl);
509 let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
510 let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
511 exprs.push(default_call_expr);
512 } else if x.mode.is_rev() {
513 if primal_ret {
514 d_ret_ty = d_ret_ty[1..].to_vec().into();
516 }
517
518 for arg in d_ret_ty.iter() {
519 let arg = arg.kind.is_simple_path().unwrap();
520 let sl: Vec<Symbol> = vec![arg, kw::Default];
521 let tmp = ecx.def_site_path(&sl);
522 let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
523 let default_call_expr =
524 ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
525 exprs.push(default_call_expr);
526 }
527 }
528
529 let ret: P<ast::Expr>;
530 match &exprs[..] {
531 [] => {
532 assert!(!has_ret(&d_sig.decl.output));
533 return body;
535 }
536 [arg] => {
537 ret = ecx.expr_call(
538 new_decl_span,
539 blackbox_call_expr.clone(),
540 thin_vec![arg.clone()],
541 );
542 }
543 args => {
544 let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, args.into());
545 ret =
546 ecx.expr_call(new_decl_span, blackbox_call_expr.clone(), thin_vec![ret_tuple]);
547 }
548 }
549 assert!(has_ret(&d_sig.decl.output));
550 body.stmts.push(ecx.stmt_expr(ret));
551
552 body
553 }
554
555 fn gen_primal_call(
556 ecx: &ExtCtxt<'_>,
557 span: Span,
558 primal: Ident,
559 idents: Vec<Ident>,
560 ) -> P<ast::Expr> {
561 let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
562 if has_self {
563 let args: ThinVec<_> =
564 idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
565 let self_expr = ecx.expr_self(span);
566 ecx.expr_method_call(span, self_expr, primal, args.clone())
567 } else {
568 let args: ThinVec<_> =
569 idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
570 let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
571 ecx.expr_call(span, primal_call_expr, args)
572 }
573 }
574
575 fn gen_enzyme_decl(
585 ecx: &ExtCtxt<'_>,
586 sig: &ast::FnSig,
587 x: &AutoDiffAttrs,
588 span: Span,
589 ) -> (ast::FnSig, Vec<String>, Vec<Ident>, bool) {
590 let dcx = ecx.sess.dcx();
591 let has_ret = has_ret(&sig.decl.output);
592 let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
593 let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 };
594 if sig_args != num_activities {
595 dcx.emit_err(errors::AutoDiffInvalidNumberActivities {
596 span,
597 expected: sig_args,
598 found: num_activities,
599 });
600 return (sig.clone(), vec![], vec![], true);
602 }
603 assert!(sig.decl.inputs.len() == x.input_activity.len());
604 assert!(has_ret == x.has_ret_activity());
605 let mut d_decl = sig.decl.clone();
606 let mut d_inputs = Vec::new();
607 let mut new_inputs = Vec::new();
608 let mut idents = Vec::new();
609 let mut act_ret = ThinVec::new();
610
611 let mut errors = false;
614 for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
615 if !valid_input_activity(x.mode, *activity) {
616 dcx.emit_err(errors::AutoDiffInvalidApplicationModeAct {
617 span,
618 mode: x.mode.to_string(),
619 act: activity.to_string(),
620 });
621 errors = true;
622 }
623 if !valid_ty_for_activity(&arg.ty, *activity) {
624 dcx.emit_err(errors::AutoDiffInvalidTypeForActivity {
625 span: arg.ty.span,
626 act: activity.to_string(),
627 });
628 errors = true;
629 }
630 }
631 if errors {
632 return (sig.clone(), new_inputs, idents, true);
634 }
635 let unsafe_activities = x
636 .input_activity
637 .iter()
638 .any(|&act| matches!(act, DiffActivity::DuplicatedOnly | DiffActivity::DualOnly));
639 for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
640 d_inputs.push(arg.clone());
641 match activity {
642 DiffActivity::Active => {
643 act_ret.push(arg.ty.clone());
644 }
645 DiffActivity::ActiveOnly => {
646 }
649 DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
650 let mut shadow_arg = arg.clone();
651 shadow_arg.ty = P(assure_mut_ref(&arg.ty));
653 let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
654 ident.name
655 } else {
656 debug!("{:#?}", &shadow_arg.pat);
657 panic!("not an ident?");
658 };
659 let name: String = format!("d{}", old_name);
660 new_inputs.push(name.clone());
661 let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
662 shadow_arg.pat = P(ast::Pat {
663 id: ast::DUMMY_NODE_ID,
664 kind: PatKind::Ident(BindingMode::NONE, ident, None),
665 span: shadow_arg.pat.span,
666 tokens: shadow_arg.pat.tokens.clone(),
667 });
668 d_inputs.push(shadow_arg);
669 }
670 DiffActivity::Dual | DiffActivity::DualOnly => {
671 let mut shadow_arg = arg.clone();
672 let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
673 ident.name
674 } else {
675 debug!("{:#?}", &shadow_arg.pat);
676 panic!("not an ident?");
677 };
678 let name: String = format!("b{}", old_name);
679 new_inputs.push(name.clone());
680 let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
681 shadow_arg.pat = P(ast::Pat {
682 id: ast::DUMMY_NODE_ID,
683 kind: PatKind::Ident(BindingMode::NONE, ident, None),
684 span: shadow_arg.pat.span,
685 tokens: shadow_arg.pat.tokens.clone(),
686 });
687 d_inputs.push(shadow_arg);
688 }
689 DiffActivity::Const => {
690 }
692 DiffActivity::None | DiffActivity::FakeActivitySize => {
693 panic!("Should not happen");
694 }
695 }
696 if let PatKind::Ident(_, ident, _) = arg.pat.kind {
697 idents.push(ident.clone());
698 } else {
699 panic!("not an ident?");
700 }
701 }
702
703 let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;
704 if active_only_ret {
705 assert!(x.mode.is_rev());
706 }
707
708 if x.mode.is_rev() {
711 match x.ret_activity {
712 DiffActivity::Active | DiffActivity::ActiveOnly => {
713 let ty = match d_decl.output {
714 FnRetTy::Ty(ref ty) => ty.clone(),
715 FnRetTy::Default(span) => {
716 panic!("Did not expect Default ret ty: {:?}", span);
717 }
718 };
719 let name = "dret".to_string();
720 let ident = Ident::from_str_and_span(&name, ty.span);
721 let shadow_arg = ast::Param {
722 attrs: ThinVec::new(),
723 ty: ty.clone(),
724 pat: P(ast::Pat {
725 id: ast::DUMMY_NODE_ID,
726 kind: PatKind::Ident(BindingMode::NONE, ident, None),
727 span: ty.span,
728 tokens: None,
729 }),
730 id: ast::DUMMY_NODE_ID,
731 span: ty.span,
732 is_placeholder: false,
733 };
734 d_inputs.push(shadow_arg);
735 new_inputs.push(name);
736 }
737 _ => {}
738 }
739 }
740 d_decl.inputs = d_inputs.into();
741
742 if x.mode.is_fwd() {
743 if let DiffActivity::Dual = x.ret_activity {
744 let ty = match d_decl.output {
745 FnRetTy::Ty(ref ty) => ty.clone(),
746 FnRetTy::Default(span) => {
747 panic!("Did not expect Default ret ty: {:?}", span);
748 }
749 };
750 let kind = TyKind::Tup(thin_vec![ty.clone(), ty.clone()]);
753 let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
754 d_decl.output = FnRetTy::Ty(ty);
755 }
756 if let DiffActivity::DualOnly = x.ret_activity {
757 }
761 }
762
763 d_decl.output =
765 if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() };
766
767 trace!("act_ret: {:?}", act_ret);
768
769 if act_ret.len() > 0 {
773 let ret_ty = match d_decl.output {
774 FnRetTy::Ty(ref ty) => {
775 if !active_only_ret {
776 act_ret.insert(0, ty.clone());
777 }
778 let kind = TyKind::Tup(act_ret);
779 P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None })
780 }
781 FnRetTy::Default(span) => {
782 if act_ret.len() == 1 {
783 act_ret[0].clone()
784 } else {
785 let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect());
786 P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None })
787 }
788 }
789 };
790 d_decl.output = FnRetTy::Ty(ret_ty);
791 }
792
793 let mut d_header = sig.header.clone();
794 if unsafe_activities {
795 d_header.safety = rustc_ast::Safety::Unsafe(span);
796 }
797 let d_sig = FnSig { header: d_header, decl: d_decl, span };
798 trace!("Generated signature: {:?}", d_sig);
799 (d_sig, new_inputs, idents, false)
800 }
801}
802
803#[cfg(not(llvm_enzyme))]
804mod ad_fallback {
805 use rustc_ast::ast;
806 use rustc_expand::base::{Annotatable, ExtCtxt};
807 use rustc_span::Span;
808
809 use crate::errors;
810 pub(crate) fn expand(
811 ecx: &mut ExtCtxt<'_>,
812 _expand_span: Span,
813 meta_item: &ast::MetaItem,
814 item: Annotatable,
815 ) -> Vec<Annotatable> {
816 ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
817 return vec![item];
818 }
819}
820
821#[cfg(not(llvm_enzyme))]
822pub(crate) use ad_fallback::expand;
823#[cfg(llvm_enzyme)]
824pub(crate) use llvm_enzyme::expand;