1mod llvm_enzyme {
7 use std::str::FromStr;
8 use std::string::String;
9
10 use rustc_ast::expand::autodiff_attrs::{
11 AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
12 valid_ty_for_activity,
13 };
14 use rustc_ast::token::{Lit, LitKind, Token, TokenKind};
15 use rustc_ast::tokenstream::*;
16 use rustc_ast::visit::AssocCtxt::*;
17 use rustc_ast::{
18 self as ast, AngleBracketedArg, AngleBracketedArgs, AnonConst, AssocItemKind, BindingMode,
19 FnRetTy, FnSig, GenericArg, GenericArgs, GenericParamKind, Generics, ItemKind,
20 MetaItemInner, MgcaDisambiguation, PatKind, Path, PathSegment, TyKind, Visibility,
21 };
22 use rustc_expand::base::{Annotatable, ExtCtxt};
23 use rustc_span::{Ident, Span, Symbol, sym};
24 use thin_vec::{ThinVec, thin_vec};
25 use tracing::{debug, trace};
26
27 use crate::errors;
28
29 pub(crate) fn outer_normal_attr(
30 kind: &Box<rustc_ast::NormalAttr>,
31 id: rustc_ast::AttrId,
32 span: Span,
33 ) -> rustc_ast::Attribute {
34 let style = rustc_ast::AttrStyle::Outer;
35 let kind = rustc_ast::AttrKind::Normal(kind.clone());
36 rustc_ast::Attribute { kind, id, style, span }
37 }
38
39 fn has_ret(ty: &FnRetTy) -> bool {
42 match ty {
43 FnRetTy::Ty(ty) => !ty.kind.is_unit(),
44 FnRetTy::Default(_) => false,
45 }
46 }
47 fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {
48 if let Some(l) = x.lit() {
49 match l.kind {
50 ast::LitKind::Int(val, _) => {
51 return rustc_span::Ident::from_str(val.get().to_string().as_str());
53 }
54 _ => {}
55 }
56 }
57
58 let segments = &x.meta_item().unwrap().path.segments;
59 assert!(segments.len() == 1);
60 segments[0].ident
61 }
62
63 fn name(x: &MetaItemInner) -> String {
64 first_ident(x).name.to_string()
65 }
66
67 fn width(x: &MetaItemInner) -> Option<u128> {
68 let lit = x.lit()?;
69 match lit.kind {
70 ast::LitKind::Int(x, _) => Some(x.get()),
71 _ => return None,
72 }
73 }
74
75 fn extract_item_info(iitem: &Box<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {
77 match &iitem.kind {
78 ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
79 Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
80 }
81 _ => None,
82 }
83 }
84
85 pub(crate) fn from_ast(
86 ecx: &mut ExtCtxt<'_>,
87 meta_item: &ThinVec<MetaItemInner>,
88 has_ret: bool,
89 mode: DiffMode,
90 ) -> AutoDiffAttrs {
91 let dcx = ecx.sess.dcx();
92
93 let mut first_activity = 1;
96
97 let width = if let [_, x, ..] = &meta_item[..]
98 && let Some(x) = width(x)
99 {
100 first_activity = 2;
101 match x.try_into() {
102 Ok(x) => x,
103 Err(_) => {
104 dcx.emit_err(errors::AutoDiffInvalidWidth {
105 span: meta_item[1].span(),
106 width: x,
107 });
108 return AutoDiffAttrs::error();
109 }
110 }
111 } else {
112 1
113 };
114
115 let mut activities: Vec<DiffActivity> = vec![];
116 let mut errors = false;
117 for x in &meta_item[first_activity..] {
118 let activity_str = name(&x);
119 let res = DiffActivity::from_str(&activity_str);
120 match res {
121 Ok(x) => activities.push(x),
122 Err(_) => {
123 dcx.emit_err(errors::AutoDiffUnknownActivity {
124 span: x.span(),
125 act: activity_str,
126 });
127 errors = true;
128 }
129 };
130 }
131 if errors {
132 return AutoDiffAttrs::error();
133 }
134
135 let (ret_activity, input_activity) = if has_ret {
138 let Some((last, rest)) = activities.split_last() else {
139 unreachable!(
140 "should not be reachable because we counted the number of activities previously"
141 );
142 };
143 (last, rest)
144 } else {
145 (&DiffActivity::None, activities.as_slice())
146 };
147
148 AutoDiffAttrs {
149 mode,
150 width,
151 ret_activity: *ret_activity,
152 input_activity: input_activity.to_vec(),
153 }
154 }
155
156 fn meta_item_inner_to_ts(t: &MetaItemInner, ts: &mut Vec<TokenTree>) {
157 let comma: Token = Token::new(TokenKind::Comma, Span::default());
158 let val = first_ident(t);
159 let t = Token::from_ast_ident(val);
160 ts.push(TokenTree::Token(t, Spacing::Joint));
161 ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
162 }
163
164 pub(crate) fn expand_forward(
165 ecx: &mut ExtCtxt<'_>,
166 expand_span: Span,
167 meta_item: &ast::MetaItem,
168 item: Annotatable,
169 ) -> Vec<Annotatable> {
170 expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward)
171 }
172
173 pub(crate) fn expand_reverse(
174 ecx: &mut ExtCtxt<'_>,
175 expand_span: Span,
176 meta_item: &ast::MetaItem,
177 item: Annotatable,
178 ) -> Vec<Annotatable> {
179 expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse)
180 }
181
182 pub(crate) fn expand_with_mode(
206 ecx: &mut ExtCtxt<'_>,
207 expand_span: Span,
208 meta_item: &ast::MetaItem,
209 mut item: Annotatable,
210 mode: DiffMode,
211 ) -> Vec<Annotatable> {
212 let dcx = ecx.sess.dcx();
213
214 let Some((vis, sig, primal, generics, impl_of_trait)) = (match &item {
218 Annotatable::Item(iitem) => {
219 extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false))
220 }
221 Annotatable::Stmt(stmt) => match &stmt.kind {
222 ast::StmtKind::Item(iitem) => {
223 extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false))
224 }
225 _ => None,
226 },
227 Annotatable::AssocItem(assoc_item, Impl { of_trait }) => match &assoc_item.kind {
228 ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => Some((
229 assoc_item.vis.clone(),
230 sig.clone(),
231 ident.clone(),
232 generics.clone(),
233 *of_trait,
234 )),
235 _ => None,
236 },
237 _ => None,
238 }) else {
239 dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
240 return vec![item];
241 };
242
243 let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
244 ast::MetaItemKind::List(ref vec) => vec.clone(),
245 _ => {
246 dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
247 return vec![item];
248 }
249 };
250
251 let has_ret = has_ret(&sig.decl.output);
252
253 let mut ts: Vec<TokenTree> = vec![];
256 if meta_item_vec.len() < 1 {
257 dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
259 return vec![item];
260 }
261
262 let mode_symbol = match mode {
263 DiffMode::Forward => sym::Forward,
264 DiffMode::Reverse => sym::Reverse,
265 _ => unreachable!("Unsupported mode: {:?}", mode),
266 };
267
268 let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default());
270 ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint));
271 ts.insert(
272 1,
273 TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone),
274 );
275
276 let start_position;
279 let kind: LitKind = LitKind::Integer;
280 let symbol;
281 if meta_item_vec.len() >= 2
282 && let Some(width) = width(&meta_item_vec[1])
283 {
284 start_position = 2;
285 symbol = Symbol::intern(&width.to_string());
286 } else {
287 start_position = 1;
288 symbol = sym::integer(1);
289 }
290
291 let l: Lit = Lit { kind, symbol, suffix: None };
292 let t = Token::new(TokenKind::Literal(l), Span::default());
293 let comma = Token::new(TokenKind::Comma, Span::default());
294 ts.push(TokenTree::Token(t, Spacing::Joint));
295 ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
296
297 for t in meta_item_vec.clone()[start_position..].iter() {
298 meta_item_inner_to_ts(t, &mut ts);
299 }
300
301 if !has_ret {
302 let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
305 ts.push(TokenTree::Token(t, Spacing::Joint));
306 ts.push(TokenTree::Token(comma, Spacing::Alone));
307 }
308 ts.pop();
310 let ts: TokenStream = TokenStream::from_iter(ts);
311
312 let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret, mode);
313 if !x.is_active() {
314 return vec![item];
317 }
318 let span = ecx.with_def_site_ctxt(expand_span);
319
320 let d_sig = gen_enzyme_decl(ecx, &sig, &x, span);
321
322 let d_body = ecx.block(
323 span,
324 thin_vec![call_autodiff(
325 ecx,
326 primal,
327 first_ident(&meta_item_vec[0]),
328 span,
329 &d_sig,
330 &generics,
331 impl_of_trait,
332 )],
333 );
334
335 let d_fn = Box::new(ast::Fn {
337 defaultness: ast::Defaultness::Final,
338 sig: d_sig,
339 ident: first_ident(&meta_item_vec[0]),
340 generics,
341 contract: None,
342 body: Some(d_body),
343 define_opaque: None,
344 eii_impls: ThinVec::new(),
345 });
346 let mut rustc_ad_attr =
347 Box::new(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
348
349 let ts2: Vec<TokenTree> = vec![TokenTree::Token(
350 Token::new(TokenKind::Ident(sym::never, false.into()), span),
351 Spacing::Joint,
352 )];
353 let never_arg = ast::DelimArgs {
354 dspan: DelimSpan::from_single(span),
355 delim: ast::token::Delimiter::Parenthesis,
356 tokens: TokenStream::from_iter(ts2),
357 };
358 let inline_item = ast::AttrItem {
359 unsafety: ast::Safety::Default,
360 path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),
361 args: rustc_ast::ast::AttrItemKind::Unparsed(ast::AttrArgs::Delimited(never_arg)),
362 tokens: None,
363 };
364 let inline_never_attr = Box::new(ast::NormalAttr { item: inline_item, tokens: None });
365 let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
366 let attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
367 let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
368 let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);
369
370 fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
372 match (attr, item) {
373 (ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {
374 let a = &a.item.path;
375 let b = &b.item.path;
376 a.segments.iter().eq_by(&b.segments, |a, b| a.ident == b.ident)
377 }
378 _ => false,
379 }
380 }
381
382 let mut has_inline_never = false;
383
384 let orig_annotatable: Annotatable = match item {
386 Annotatable::Item(ref mut iitem) => {
387 if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
388 iitem.attrs.push(attr);
389 }
390 if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
391 has_inline_never = true;
392 }
393 Annotatable::Item(iitem.clone())
394 }
395 Annotatable::AssocItem(ref mut assoc_item, i @ Impl { .. }) => {
396 if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
397 assoc_item.attrs.push(attr);
398 }
399 if assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
400 has_inline_never = true;
401 }
402 Annotatable::AssocItem(assoc_item.clone(), i)
403 }
404 Annotatable::Stmt(ref mut stmt) => {
405 match stmt.kind {
406 ast::StmtKind::Item(ref mut iitem) => {
407 if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
408 iitem.attrs.push(attr);
409 }
410 if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
411 has_inline_never = true;
412 }
413 }
414 _ => unreachable!("stmt kind checked previously"),
415 };
416
417 Annotatable::Stmt(stmt.clone())
418 }
419 _ => {
420 unreachable!("annotatable kind checked previously")
421 }
422 };
423 rustc_ad_attr.item.args = rustc_ast::ast::AttrItemKind::Unparsed(
425 rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {
426 dspan: DelimSpan::dummy(),
427 delim: rustc_ast::token::Delimiter::Parenthesis,
428 tokens: ts,
429 }),
430 );
431
432 let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
433 let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
434
435 let mut d_attrs = thin_vec![d_attr];
437
438 if has_inline_never {
439 d_attrs.push(inline_never);
440 }
441
442 let d_annotatable = match &item {
443 Annotatable::AssocItem(_, _) => {
444 let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(d_fn);
445 let d_fn = Box::new(ast::AssocItem {
446 attrs: d_attrs,
447 id: ast::DUMMY_NODE_ID,
448 span,
449 vis,
450 kind: assoc_item,
451 tokens: None,
452 });
453 Annotatable::AssocItem(d_fn, Impl { of_trait: false })
454 }
455 Annotatable::Item(_) => {
456 let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));
457 d_fn.vis = vis;
458
459 Annotatable::Item(d_fn)
460 }
461 Annotatable::Stmt(_) => {
462 let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));
463 d_fn.vis = vis;
464
465 Annotatable::Stmt(Box::new(ast::Stmt {
466 id: ast::DUMMY_NODE_ID,
467 kind: ast::StmtKind::Item(d_fn),
468 span,
469 }))
470 }
471 _ => {
472 unreachable!("item kind checked previously")
473 }
474 };
475
476 return vec![orig_annotatable, d_annotatable];
477 }
478
479 fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
482 let mut ty = ty.clone();
483 match ty.kind {
484 TyKind::Ptr(ref mut mut_ty) => {
485 mut_ty.mutbl = ast::Mutability::Mut;
486 }
487 TyKind::Ref(_, ref mut mut_ty) => {
488 mut_ty.mutbl = ast::Mutability::Mut;
489 }
490 _ => {
491 panic!("unsupported type: {:?}", ty);
492 }
493 }
494 ty
495 }
496
497 fn call_autodiff(
502 ecx: &ExtCtxt<'_>,
503 primal: Ident,
504 diff: Ident,
505 span: Span,
506 d_sig: &FnSig,
507 generics: &Generics,
508 is_impl: bool,
509 ) -> rustc_ast::Stmt {
510 let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span, is_impl);
511 let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span, is_impl);
512
513 let tuple_expr = ecx.expr_tuple(
514 span,
515 d_sig
516 .decl
517 .inputs
518 .iter()
519 .map(|arg| match arg.pat.kind {
520 PatKind::Ident(_, ident, _) => ecx.expr_path(ecx.path_ident(span, ident)),
521 _ => todo!(),
522 })
523 .collect::<ThinVec<_>>()
524 .into(),
525 );
526
527 let enzyme_path_idents = ecx.std_path(&[sym::intrinsics, sym::autodiff]);
528 let enzyme_path = ecx.path(span, enzyme_path_idents);
529 let call_expr = ecx.expr_call(
530 span,
531 ecx.expr_path(enzyme_path),
532 vec![primal_path_expr, diff_path_expr, tuple_expr].into(),
533 );
534
535 ecx.stmt_expr(call_expr)
536 }
537
538 fn gen_turbofish_expr(
542 ecx: &ExtCtxt<'_>,
543 ident: Ident,
544 generics: &Generics,
545 span: Span,
546 is_impl: bool,
547 ) -> Box<ast::Expr> {
548 let generic_args = generics
549 .params
550 .iter()
551 .filter_map(|p| match &p.kind {
552 GenericParamKind::Type { .. } => {
553 let path = ast::Path::from_ident(p.ident);
554 let ty = ecx.ty_path(path);
555 Some(AngleBracketedArg::Arg(GenericArg::Type(ty)))
556 }
557 GenericParamKind::Const { .. } => {
558 let expr = ecx.expr_path(ast::Path::from_ident(p.ident));
559 let anon_const = AnonConst {
560 id: ast::DUMMY_NODE_ID,
561 value: expr,
562 mgca_disambiguation: MgcaDisambiguation::Direct,
563 };
564 Some(AngleBracketedArg::Arg(GenericArg::Const(anon_const)))
565 }
566 GenericParamKind::Lifetime { .. } => None,
567 })
568 .collect::<ThinVec<_>>();
569
570 let args: AngleBracketedArgs = AngleBracketedArgs { span, args: generic_args };
571
572 let segment = PathSegment {
573 ident,
574 id: ast::DUMMY_NODE_ID,
575 args: Some(Box::new(GenericArgs::AngleBracketed(args))),
576 };
577
578 let segments = if is_impl {
579 thin_vec![
580 PathSegment { ident: Ident::from_str("Self"), id: ast::DUMMY_NODE_ID, args: None },
581 segment,
582 ]
583 } else {
584 thin_vec![segment]
585 };
586
587 let path = Path { span, segments, tokens: None };
588
589 ecx.expr_path(path)
590 }
591
592 fn gen_enzyme_decl(
604 ecx: &ExtCtxt<'_>,
605 sig: &ast::FnSig,
606 x: &AutoDiffAttrs,
607 span: Span,
608 ) -> ast::FnSig {
609 let dcx = ecx.sess.dcx();
610 let has_ret = has_ret(&sig.decl.output);
611 let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
612 let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 };
613 if sig_args != num_activities {
614 dcx.emit_err(errors::AutoDiffInvalidNumberActivities {
615 span,
616 expected: sig_args,
617 found: num_activities,
618 });
619 return sig.clone();
621 }
622 assert!(sig.decl.inputs.len() == x.input_activity.len());
623 assert!(has_ret == x.has_ret_activity());
624 let mut d_decl = sig.decl.clone();
625 let mut d_inputs = Vec::new();
626 let mut new_inputs = Vec::new();
627 let mut idents = Vec::new();
628 let mut act_ret = ThinVec::new();
629
630 let mut errors = false;
633 for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
634 if !valid_input_activity(x.mode, *activity) {
635 dcx.emit_err(errors::AutoDiffInvalidApplicationModeAct {
636 span,
637 mode: x.mode.to_string(),
638 act: activity.to_string(),
639 });
640 errors = true;
641 }
642 if !valid_ty_for_activity(&arg.ty, *activity) {
643 dcx.emit_err(errors::AutoDiffInvalidTypeForActivity {
644 span: arg.ty.span,
645 act: activity.to_string(),
646 });
647 errors = true;
648 }
649 }
650
651 if has_ret && !valid_ret_activity(x.mode, x.ret_activity) {
652 dcx.emit_err(errors::AutoDiffInvalidRetAct {
653 span,
654 mode: x.mode.to_string(),
655 act: x.ret_activity.to_string(),
656 });
657 }
660
661 if errors {
662 return sig.clone();
664 }
665
666 let unsafe_activities = x
667 .input_activity
668 .iter()
669 .any(|&act| matches!(act, DiffActivity::DuplicatedOnly | DiffActivity::DualOnly));
670 for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
671 d_inputs.push(arg.clone());
672 match activity {
673 DiffActivity::Active => {
674 act_ret.push(arg.ty.clone());
675 }
677 DiffActivity::ActiveOnly => {
678 }
681 DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
682 for i in 0..x.width {
683 let mut shadow_arg = arg.clone();
684 shadow_arg.ty = Box::new(assure_mut_ref(&arg.ty));
686 let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
687 ident.name
688 } else {
689 debug!("{:#?}", &shadow_arg.pat);
690 panic!("not an ident?");
691 };
692 let name: String = format!("d{}_{}", old_name, i);
693 new_inputs.push(name.clone());
694 let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
695 shadow_arg.pat = Box::new(ast::Pat {
696 id: ast::DUMMY_NODE_ID,
697 kind: PatKind::Ident(BindingMode::NONE, ident, None),
698 span: shadow_arg.pat.span,
699 tokens: shadow_arg.pat.tokens.clone(),
700 });
701 d_inputs.push(shadow_arg.clone());
702 }
703 }
704 DiffActivity::Dual
705 | DiffActivity::DualOnly
706 | DiffActivity::Dualv
707 | DiffActivity::DualvOnly => {
708 let iterations =
711 if matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) {
712 1
713 } else {
714 x.width
715 };
716 for i in 0..iterations {
717 let mut shadow_arg = arg.clone();
718 let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
719 ident.name
720 } else {
721 debug!("{:#?}", &shadow_arg.pat);
722 panic!("not an ident?");
723 };
724 let name: String = format!("b{}_{}", old_name, i);
725 new_inputs.push(name.clone());
726 let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
727 shadow_arg.pat = Box::new(ast::Pat {
728 id: ast::DUMMY_NODE_ID,
729 kind: PatKind::Ident(BindingMode::NONE, ident, None),
730 span: shadow_arg.pat.span,
731 tokens: shadow_arg.pat.tokens.clone(),
732 });
733 d_inputs.push(shadow_arg.clone());
734 }
735 }
736 DiffActivity::Const => {
737 }
739 DiffActivity::None | DiffActivity::FakeActivitySize(_) => {
740 panic!("Should not happen");
741 }
742 }
743 if let PatKind::Ident(_, ident, _) = arg.pat.kind {
744 idents.push(ident.clone());
745 } else {
746 panic!("not an ident?");
747 }
748 }
749
750 let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;
751 if active_only_ret {
752 assert!(x.mode.is_rev());
753 }
754
755 if x.mode.is_rev() {
758 match x.ret_activity {
759 DiffActivity::Active | DiffActivity::ActiveOnly => {
760 let ty = match d_decl.output {
761 FnRetTy::Ty(ref ty) => ty.clone(),
762 FnRetTy::Default(span) => {
763 panic!("Did not expect Default ret ty: {:?}", span);
764 }
765 };
766 let name = "dret".to_string();
767 let ident = Ident::from_str_and_span(&name, ty.span);
768 let shadow_arg = ast::Param {
769 attrs: ThinVec::new(),
770 ty: ty.clone(),
771 pat: Box::new(ast::Pat {
772 id: ast::DUMMY_NODE_ID,
773 kind: PatKind::Ident(BindingMode::NONE, ident, None),
774 span: ty.span,
775 tokens: None,
776 }),
777 id: ast::DUMMY_NODE_ID,
778 span: ty.span,
779 is_placeholder: false,
780 };
781 d_inputs.push(shadow_arg);
782 new_inputs.push(name);
783 }
784 _ => {}
785 }
786 }
787 d_decl.inputs = d_inputs.into();
788
789 if x.mode.is_fwd() {
790 let ty = match d_decl.output {
791 FnRetTy::Ty(ref ty) => ty.clone(),
792 FnRetTy::Default(span) => {
793 let kind = TyKind::Tup(ThinVec::new());
795 let ty = Box::new(rustc_ast::Ty {
796 kind,
797 id: ast::DUMMY_NODE_ID,
798 span,
799 tokens: None,
800 });
801 d_decl.output = FnRetTy::Ty(ty.clone());
802 assert!(matches!(x.ret_activity, DiffActivity::None));
803 ty
805 }
806 };
807
808 if matches!(x.ret_activity, DiffActivity::Dual | DiffActivity::Dualv) {
809 let kind = if x.width == 1 || matches!(x.ret_activity, DiffActivity::Dualv) {
810 TyKind::Tup(thin_vec![ty.clone(), ty.clone()])
813 } else {
814 let anon_const = rustc_ast::AnonConst {
816 id: ast::DUMMY_NODE_ID,
817 value: ecx.expr_usize(span, 1 + x.width as usize),
818 mgca_disambiguation: MgcaDisambiguation::Direct,
819 };
820 TyKind::Array(ty.clone(), anon_const)
821 };
822 let ty = Box::new(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
823 d_decl.output = FnRetTy::Ty(ty);
824 }
825 if matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) {
826 if x.width > 1 {
830 let anon_const = rustc_ast::AnonConst {
831 id: ast::DUMMY_NODE_ID,
832 value: ecx.expr_usize(span, x.width as usize),
833 mgca_disambiguation: MgcaDisambiguation::Direct,
834 };
835 let kind = TyKind::Array(ty.clone(), anon_const);
836 let ty =
837 Box::new(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
838 d_decl.output = FnRetTy::Ty(ty);
839 }
840 }
841 }
842
843 d_decl.output =
845 if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() };
846
847 trace!("act_ret: {:?}", act_ret);
848
849 if act_ret.len() > 0 {
853 let ret_ty = match d_decl.output {
854 FnRetTy::Ty(ref ty) => {
855 if !active_only_ret {
856 act_ret.insert(0, ty.clone());
857 }
858 let kind = TyKind::Tup(act_ret);
859 Box::new(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None })
860 }
861 FnRetTy::Default(span) => {
862 if act_ret.len() == 1 {
863 act_ret[0].clone()
864 } else {
865 let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect());
866 Box::new(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None })
867 }
868 }
869 };
870 d_decl.output = FnRetTy::Ty(ret_ty);
871 }
872
873 let mut d_header = sig.header.clone();
874 if unsafe_activities {
875 d_header.safety = rustc_ast::Safety::Unsafe(span);
876 }
877 let d_sig = FnSig { header: d_header, decl: d_decl, span };
878 trace!("Generated signature: {:?}", d_sig);
879 d_sig
880 }
881}
882
883pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};