1mod llvm_enzyme {
7 use std::str::FromStr;
8 use std::string::String;
9
10 use rustc_ast::expand::autodiff_attrs::{
11 AutoDiffAttrs, DiffActivity, DiffMode, valid_input_activity, valid_ret_activity,
12 valid_ty_for_activity,
13 };
14 use rustc_ast::ptr::P;
15 use rustc_ast::token::{Token, TokenKind};
16 use rustc_ast::tokenstream::*;
17 use rustc_ast::visit::AssocCtxt::*;
18 use rustc_ast::{
19 self as ast, AssocItemKind, BindingMode, FnRetTy, FnSig, Generics, ItemKind, MetaItemInner,
20 PatKind, TyKind,
21 };
22 use rustc_expand::base::{Annotatable, ExtCtxt};
23 use rustc_span::{Ident, Span, Symbol, kw, sym};
24 use thin_vec::{ThinVec, thin_vec};
25 use tracing::{debug, trace};
26
27 use crate::errors;
28
29 pub(crate) fn outer_normal_attr(
30 kind: &P<rustc_ast::NormalAttr>,
31 id: rustc_ast::AttrId,
32 span: Span,
33 ) -> rustc_ast::Attribute {
34 let style = rustc_ast::AttrStyle::Outer;
35 let kind = rustc_ast::AttrKind::Normal(kind.clone());
36 rustc_ast::Attribute { kind, id, style, span }
37 }
38
39 fn has_ret(ty: &FnRetTy) -> bool {
42 match ty {
43 FnRetTy::Ty(ty) => !ty.kind.is_unit(),
44 FnRetTy::Default(_) => false,
45 }
46 }
47 fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {
48 let segments = &x.meta_item().unwrap().path.segments;
49 assert!(segments.len() == 1);
50 segments[0].ident
51 }
52
53 fn name(x: &MetaItemInner) -> String {
54 first_ident(x).name.to_string()
55 }
56
57 pub(crate) fn from_ast(
58 ecx: &mut ExtCtxt<'_>,
59 meta_item: &ThinVec<MetaItemInner>,
60 has_ret: bool,
61 ) -> AutoDiffAttrs {
62 let dcx = ecx.sess.dcx();
63 let mode = name(&meta_item[1]);
64 let Ok(mode) = DiffMode::from_str(&mode) else {
65 dcx.emit_err(errors::AutoDiffInvalidMode { span: meta_item[1].span(), mode });
66 return AutoDiffAttrs::error();
67 };
68 let mut activities: Vec<DiffActivity> = vec![];
69 let mut errors = false;
70 for x in &meta_item[2..] {
71 let activity_str = name(&x);
72 let res = DiffActivity::from_str(&activity_str);
73 match res {
74 Ok(x) => activities.push(x),
75 Err(_) => {
76 dcx.emit_err(errors::AutoDiffUnknownActivity {
77 span: x.span(),
78 act: activity_str,
79 });
80 errors = true;
81 }
82 };
83 }
84 if errors {
85 return AutoDiffAttrs::error();
86 }
87
88 let (ret_activity, input_activity) = if has_ret {
91 let Some((last, rest)) = activities.split_last() else {
92 unreachable!(
93 "should not be reachable because we counted the number of activities previously"
94 );
95 };
96 (last, rest)
97 } else {
98 (&DiffActivity::None, activities.as_slice())
99 };
100
101 AutoDiffAttrs { mode, ret_activity: *ret_activity, input_activity: input_activity.to_vec() }
102 }
103
104 pub(crate) fn expand(
138 ecx: &mut ExtCtxt<'_>,
139 expand_span: Span,
140 meta_item: &ast::MetaItem,
141 mut item: Annotatable,
142 ) -> Vec<Annotatable> {
143 if cfg!(not(llvm_enzyme)) {
144 ecx.sess.dcx().emit_err(errors::AutoDiffSupportNotBuild { span: meta_item.span });
145 return vec![item];
146 }
147 let dcx = ecx.sess.dcx();
148 let (sig, is_impl): (FnSig, bool) = match &item {
150 Annotatable::Item(iitem) => {
151 let sig = match &iitem.kind {
152 ItemKind::Fn(box ast::Fn { sig, .. }) => sig,
153 _ => {
154 dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
155 return vec![item];
156 }
157 };
158 (sig.clone(), false)
159 }
160 Annotatable::AssocItem(assoc_item, Impl { of_trait: false }) => {
161 let sig = match &assoc_item.kind {
162 ast::AssocItemKind::Fn(box ast::Fn { sig, .. }) => sig,
163 _ => {
164 dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
165 return vec![item];
166 }
167 };
168 (sig.clone(), true)
169 }
170 _ => {
171 dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
172 return vec![item];
173 }
174 };
175
176 let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
177 ast::MetaItemKind::List(ref vec) => vec.clone(),
178 _ => {
179 dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
180 return vec![item];
181 }
182 };
183
184 let has_ret = has_ret(&sig.decl.output);
185 let sig_span = ecx.with_call_site_ctxt(sig.span);
186
187 let (vis, primal) = match &item {
188 Annotatable::Item(iitem) => (iitem.vis.clone(), iitem.ident.clone()),
189 Annotatable::AssocItem(assoc_item, _) => {
190 (assoc_item.vis.clone(), assoc_item.ident.clone())
191 }
192 _ => {
193 dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
194 return vec![item];
195 }
196 };
197
198 let comma: Token = Token::new(TokenKind::Comma, Span::default());
201 let mut ts: Vec<TokenTree> = vec![];
202 if meta_item_vec.len() < 2 {
203 dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
206 return vec![item];
207 } else {
208 for t in meta_item_vec.clone()[1..].iter() {
209 let val = first_ident(t);
210 let t = Token::from_ast_ident(val);
211 ts.push(TokenTree::Token(t, Spacing::Joint));
212 ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
213 }
214 }
215 if !has_ret {
216 let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
219 ts.push(TokenTree::Token(t, Spacing::Joint));
220 }
221 let ts: TokenStream = TokenStream::from_iter(ts);
222
223 let x: AutoDiffAttrs = from_ast(ecx, &meta_item_vec, has_ret);
224 if !x.is_active() {
225 return vec![item];
228 }
229 let span = ecx.with_def_site_ctxt(expand_span);
230
231 let n_active: u32 = x
232 .input_activity
233 .iter()
234 .filter(|a| **a == DiffActivity::Active || **a == DiffActivity::ActiveOnly)
235 .count() as u32;
236 let (d_sig, new_args, idents, errored) = gen_enzyme_decl(ecx, &sig, &x, span);
237 let d_body = gen_enzyme_body(
238 ecx, &x, n_active, &sig, &d_sig, primal, &new_args, span, sig_span, idents, errored,
239 );
240 let d_ident = first_ident(&meta_item_vec[0]);
241
242 let asdf = Box::new(ast::Fn {
244 defaultness: ast::Defaultness::Final,
245 sig: d_sig,
246 generics: Generics::default(),
247 contract: None,
248 body: Some(d_body),
249 define_opaque: None,
250 });
251 let mut rustc_ad_attr =
252 P(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
253
254 let ts2: Vec<TokenTree> = vec![TokenTree::Token(
255 Token::new(TokenKind::Ident(sym::never, false.into()), span),
256 Spacing::Joint,
257 )];
258 let never_arg = ast::DelimArgs {
259 dspan: ast::tokenstream::DelimSpan::from_single(span),
260 delim: ast::token::Delimiter::Parenthesis,
261 tokens: ast::tokenstream::TokenStream::from_iter(ts2),
262 };
263 let inline_item = ast::AttrItem {
264 unsafety: ast::Safety::Default,
265 path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),
266 args: ast::AttrArgs::Delimited(never_arg),
267 tokens: None,
268 };
269 let inline_never_attr = P(ast::NormalAttr { item: inline_item, tokens: None });
270 let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
271 let attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
272 let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
273 let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);
274
275 fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
277 match (attr, item) {
278 (ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {
279 let a = &a.item.path;
280 let b = &b.item.path;
281 a.segments.len() == b.segments.len()
282 && a.segments.iter().zip(b.segments.iter()).all(|(a, b)| a.ident == b.ident)
283 }
284 _ => false,
285 }
286 }
287
288 let orig_annotatable: Annotatable = match item {
290 Annotatable::Item(ref mut iitem) => {
291 if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
292 iitem.attrs.push(attr);
293 }
294 if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
295 iitem.attrs.push(inline_never.clone());
296 }
297 Annotatable::Item(iitem.clone())
298 }
299 Annotatable::AssocItem(ref mut assoc_item, i @ Impl { of_trait: false }) => {
300 if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
301 assoc_item.attrs.push(attr);
302 }
303 if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
304 assoc_item.attrs.push(inline_never.clone());
305 }
306 Annotatable::AssocItem(assoc_item.clone(), i)
307 }
308 _ => {
309 unreachable!("annotatable kind checked previously")
310 }
311 };
312 rustc_ad_attr.item.args = rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {
314 dspan: DelimSpan::dummy(),
315 delim: rustc_ast::token::Delimiter::Parenthesis,
316 tokens: ts,
317 });
318 let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
319 let d_annotatable = if is_impl {
320 let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(asdf);
321 let d_fn = P(ast::AssocItem {
322 attrs: thin_vec![d_attr, inline_never],
323 id: ast::DUMMY_NODE_ID,
324 span,
325 vis,
326 ident: d_ident,
327 kind: assoc_item,
328 tokens: None,
329 });
330 Annotatable::AssocItem(d_fn, Impl { of_trait: false })
331 } else {
332 let mut d_fn =
333 ecx.item(span, d_ident, thin_vec![d_attr, inline_never], ItemKind::Fn(asdf));
334 d_fn.vis = vis;
335 Annotatable::Item(d_fn)
336 };
337
338 return vec![orig_annotatable, d_annotatable];
339 }
340
341 fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
344 let mut ty = ty.clone();
345 match ty.kind {
346 TyKind::Ptr(ref mut mut_ty) => {
347 mut_ty.mutbl = ast::Mutability::Mut;
348 }
349 TyKind::Ref(_, ref mut mut_ty) => {
350 mut_ty.mutbl = ast::Mutability::Mut;
351 }
352 _ => {
353 panic!("unsupported type: {:?}", ty);
354 }
355 }
356 ty
357 }
358
359 fn init_body_helper(
371 ecx: &ExtCtxt<'_>,
372 span: Span,
373 primal: Ident,
374 new_names: &[String],
375 sig_span: Span,
376 new_decl_span: Span,
377 idents: &[Ident],
378 errored: bool,
379 ) -> (P<ast::Block>, P<ast::Expr>, P<ast::Expr>, P<ast::Expr>) {
380 let blackbox_path = ecx.std_path(&[sym::hint, sym::black_box]);
381 let noop = ast::InlineAsm {
382 asm_macro: ast::AsmMacro::Asm,
383 template: vec![ast::InlineAsmTemplatePiece::String("NOP".into())],
384 template_strs: Box::new([]),
385 operands: vec![],
386 clobber_abis: vec![],
387 options: ast::InlineAsmOptions::PURE | ast::InlineAsmOptions::NOMEM,
388 line_spans: vec![],
389 };
390 let noop_expr = ecx.expr_asm(span, P(noop));
391 let unsf = ast::BlockCheckMode::Unsafe(ast::UnsafeSource::CompilerGenerated);
392 let unsf_block = ast::Block {
393 stmts: thin_vec![ecx.stmt_semi(noop_expr)],
394 id: ast::DUMMY_NODE_ID,
395 tokens: None,
396 rules: unsf,
397 span,
398 };
399 let unsf_expr = ecx.expr_block(P(unsf_block));
400 let blackbox_call_expr = ecx.expr_path(ecx.path(span, blackbox_path));
401 let primal_call = gen_primal_call(ecx, span, primal, idents);
402 let black_box_primal_call = ecx.expr_call(
403 new_decl_span,
404 blackbox_call_expr.clone(),
405 thin_vec![primal_call.clone()],
406 );
407 let tup_args = new_names
408 .iter()
409 .map(|arg| ecx.expr_path(ecx.path_ident(span, Ident::from_str(arg))))
410 .collect();
411
412 let black_box_remaining_args = ecx.expr_call(
413 sig_span,
414 blackbox_call_expr.clone(),
415 thin_vec![ecx.expr_tuple(sig_span, tup_args)],
416 );
417
418 let mut body = ecx.block(span, ThinVec::new());
419 body.stmts.push(ecx.stmt_semi(unsf_expr));
420
421 if !errored {
423 body.stmts.push(ecx.stmt_semi(black_box_primal_call.clone()));
424 }
425 body.stmts.push(ecx.stmt_semi(black_box_remaining_args));
426
427 (body, primal_call, black_box_primal_call, blackbox_call_expr)
428 }
429
430 fn gen_enzyme_body(
439 ecx: &ExtCtxt<'_>,
440 x: &AutoDiffAttrs,
441 n_active: u32,
442 sig: &ast::FnSig,
443 d_sig: &ast::FnSig,
444 primal: Ident,
445 new_names: &[String],
446 span: Span,
447 sig_span: Span,
448 idents: Vec<Ident>,
449 errored: bool,
450 ) -> P<ast::Block> {
451 let new_decl_span = d_sig.span;
452
453 let (mut body, primal_call, bb_primal_call, bb_call_expr) = init_body_helper(
462 ecx,
463 span,
464 primal,
465 new_names,
466 sig_span,
467 new_decl_span,
468 &idents,
469 errored,
470 );
471
472 if !has_ret(&d_sig.decl.output) {
473 return body;
475 }
476
477 let primal_ret = has_ret(&sig.decl.output) && !x.has_active_only_ret();
480
481 if primal_ret && n_active == 0 && x.mode.is_rev() {
482 body.stmts.push(ecx.stmt_expr(bb_primal_call));
484 return body;
485 }
486
487 if !primal_ret && n_active == 1 {
488 let ty = match d_sig.decl.output {
490 FnRetTy::Ty(ref ty) => ty.clone(),
491 FnRetTy::Default(span) => {
492 panic!("Did not expect Default ret ty: {:?}", span);
493 }
494 };
495 let arg = ty.kind.is_simple_path().unwrap();
496 let sl: Vec<Symbol> = vec![arg, kw::Default];
497 let tmp = ecx.def_site_path(&sl);
498 let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
499 let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
500 body.stmts.push(ecx.stmt_expr(default_call_expr));
501 return body;
502 }
503
504 let mut exprs = ThinVec::<P<ast::Expr>>::new();
505 if primal_ret {
506 exprs.push(primal_call);
509 }
510
511 let d_ret_ty = match d_sig.decl.output {
514 FnRetTy::Ty(ref ty) => ty.clone(),
515 FnRetTy::Default(span) => {
516 panic!("Did not expect Default ret ty: {:?}", span);
517 }
518 };
519 let mut d_ret_ty = match d_ret_ty.kind.clone() {
520 TyKind::Tup(ref tys) => tys.clone(),
521 TyKind::Path(_, rustc_ast::Path { segments, .. }) => {
522 if let [segment] = &segments[..]
523 && segment.args.is_none()
524 {
525 let id = vec![segments[0].ident];
526 let kind = TyKind::Path(None, ecx.path(span, id));
527 let ty = P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None });
528 thin_vec![ty]
529 } else {
530 panic!("Expected tuple or simple path return type");
531 }
532 }
533 _ => {
534 panic!("Did not expect non-tuple ret ty: {:?}", d_ret_ty);
536 }
537 };
538
539 if x.mode.is_fwd() && x.ret_activity == DiffActivity::Dual {
540 assert!(d_ret_ty.len() == 2);
541 let arg = d_ret_ty[0].kind.is_simple_path().unwrap();
543 let arg2 = d_ret_ty[1].kind.is_simple_path().unwrap();
544 assert!(arg == arg2);
545 let sl: Vec<Symbol> = vec![arg, kw::Default];
546 let tmp = ecx.def_site_path(&sl);
547 let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
548 let default_call_expr = ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
549 exprs.push(default_call_expr);
550 } else if x.mode.is_rev() {
551 if primal_ret {
552 d_ret_ty = d_ret_ty[1..].to_vec().into();
554 }
555
556 for arg in d_ret_ty.iter() {
557 let arg = arg.kind.is_simple_path().unwrap();
558 let sl: Vec<Symbol> = vec![arg, kw::Default];
559 let tmp = ecx.def_site_path(&sl);
560 let default_call_expr = ecx.expr_path(ecx.path(span, tmp));
561 let default_call_expr =
562 ecx.expr_call(new_decl_span, default_call_expr, thin_vec![]);
563 exprs.push(default_call_expr);
564 }
565 }
566
567 let ret: P<ast::Expr>;
568 match &exprs[..] {
569 [] => {
570 assert!(!has_ret(&d_sig.decl.output));
571 return body;
573 }
574 [arg] => {
575 ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![arg.clone()]);
576 }
577 args => {
578 let ret_tuple: P<ast::Expr> = ecx.expr_tuple(span, args.into());
579 ret = ecx.expr_call(new_decl_span, bb_call_expr, thin_vec![ret_tuple]);
580 }
581 }
582 assert!(has_ret(&d_sig.decl.output));
583 body.stmts.push(ecx.stmt_expr(ret));
584
585 body
586 }
587
588 fn gen_primal_call(
589 ecx: &ExtCtxt<'_>,
590 span: Span,
591 primal: Ident,
592 idents: &[Ident],
593 ) -> P<ast::Expr> {
594 let has_self = idents.len() > 0 && idents[0].name == kw::SelfLower;
595 if has_self {
596 let args: ThinVec<_> =
597 idents[1..].iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
598 let self_expr = ecx.expr_self(span);
599 ecx.expr_method_call(span, self_expr, primal, args)
600 } else {
601 let args: ThinVec<_> =
602 idents.iter().map(|arg| ecx.expr_path(ecx.path_ident(span, *arg))).collect();
603 let primal_call_expr = ecx.expr_path(ecx.path_ident(span, primal));
604 ecx.expr_call(span, primal_call_expr, args)
605 }
606 }
607
608 fn gen_enzyme_decl(
620 ecx: &ExtCtxt<'_>,
621 sig: &ast::FnSig,
622 x: &AutoDiffAttrs,
623 span: Span,
624 ) -> (ast::FnSig, Vec<String>, Vec<Ident>, bool) {
625 let dcx = ecx.sess.dcx();
626 let has_ret = has_ret(&sig.decl.output);
627 let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
628 let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 };
629 if sig_args != num_activities {
630 dcx.emit_err(errors::AutoDiffInvalidNumberActivities {
631 span,
632 expected: sig_args,
633 found: num_activities,
634 });
635 return (sig.clone(), vec![], vec![], true);
637 }
638 assert!(sig.decl.inputs.len() == x.input_activity.len());
639 assert!(has_ret == x.has_ret_activity());
640 let mut d_decl = sig.decl.clone();
641 let mut d_inputs = Vec::new();
642 let mut new_inputs = Vec::new();
643 let mut idents = Vec::new();
644 let mut act_ret = ThinVec::new();
645
646 let mut errors = false;
649 for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
650 if !valid_input_activity(x.mode, *activity) {
651 dcx.emit_err(errors::AutoDiffInvalidApplicationModeAct {
652 span,
653 mode: x.mode.to_string(),
654 act: activity.to_string(),
655 });
656 errors = true;
657 }
658 if !valid_ty_for_activity(&arg.ty, *activity) {
659 dcx.emit_err(errors::AutoDiffInvalidTypeForActivity {
660 span: arg.ty.span,
661 act: activity.to_string(),
662 });
663 errors = true;
664 }
665 }
666
667 if has_ret && !valid_ret_activity(x.mode, x.ret_activity) {
668 dcx.emit_err(errors::AutoDiffInvalidRetAct {
669 span,
670 mode: x.mode.to_string(),
671 act: x.ret_activity.to_string(),
672 });
673 }
676
677 if errors {
678 return (sig.clone(), new_inputs, idents, true);
680 }
681
682 let unsafe_activities = x
683 .input_activity
684 .iter()
685 .any(|&act| matches!(act, DiffActivity::DuplicatedOnly | DiffActivity::DualOnly));
686 for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
687 d_inputs.push(arg.clone());
688 match activity {
689 DiffActivity::Active => {
690 act_ret.push(arg.ty.clone());
691 }
692 DiffActivity::ActiveOnly => {
693 }
696 DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
697 let mut shadow_arg = arg.clone();
698 shadow_arg.ty = P(assure_mut_ref(&arg.ty));
700 let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
701 ident.name
702 } else {
703 debug!("{:#?}", &shadow_arg.pat);
704 panic!("not an ident?");
705 };
706 let name: String = format!("d{}", old_name);
707 new_inputs.push(name.clone());
708 let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
709 shadow_arg.pat = P(ast::Pat {
710 id: ast::DUMMY_NODE_ID,
711 kind: PatKind::Ident(BindingMode::NONE, ident, None),
712 span: shadow_arg.pat.span,
713 tokens: shadow_arg.pat.tokens.clone(),
714 });
715 d_inputs.push(shadow_arg);
716 }
717 DiffActivity::Dual | DiffActivity::DualOnly => {
718 let mut shadow_arg = arg.clone();
719 let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
720 ident.name
721 } else {
722 debug!("{:#?}", &shadow_arg.pat);
723 panic!("not an ident?");
724 };
725 let name: String = format!("b{}", old_name);
726 new_inputs.push(name.clone());
727 let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
728 shadow_arg.pat = P(ast::Pat {
729 id: ast::DUMMY_NODE_ID,
730 kind: PatKind::Ident(BindingMode::NONE, ident, None),
731 span: shadow_arg.pat.span,
732 tokens: shadow_arg.pat.tokens.clone(),
733 });
734 d_inputs.push(shadow_arg);
735 }
736 DiffActivity::Const => {
737 }
739 DiffActivity::None | DiffActivity::FakeActivitySize => {
740 panic!("Should not happen");
741 }
742 }
743 if let PatKind::Ident(_, ident, _) = arg.pat.kind {
744 idents.push(ident.clone());
745 } else {
746 panic!("not an ident?");
747 }
748 }
749
750 let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;
751 if active_only_ret {
752 assert!(x.mode.is_rev());
753 }
754
755 if x.mode.is_rev() {
758 match x.ret_activity {
759 DiffActivity::Active | DiffActivity::ActiveOnly => {
760 let ty = match d_decl.output {
761 FnRetTy::Ty(ref ty) => ty.clone(),
762 FnRetTy::Default(span) => {
763 panic!("Did not expect Default ret ty: {:?}", span);
764 }
765 };
766 let name = "dret".to_string();
767 let ident = Ident::from_str_and_span(&name, ty.span);
768 let shadow_arg = ast::Param {
769 attrs: ThinVec::new(),
770 ty: ty.clone(),
771 pat: P(ast::Pat {
772 id: ast::DUMMY_NODE_ID,
773 kind: PatKind::Ident(BindingMode::NONE, ident, None),
774 span: ty.span,
775 tokens: None,
776 }),
777 id: ast::DUMMY_NODE_ID,
778 span: ty.span,
779 is_placeholder: false,
780 };
781 d_inputs.push(shadow_arg);
782 new_inputs.push(name);
783 }
784 _ => {}
785 }
786 }
787 d_decl.inputs = d_inputs.into();
788
789 if x.mode.is_fwd() {
790 if let DiffActivity::Dual = x.ret_activity {
791 let ty = match d_decl.output {
792 FnRetTy::Ty(ref ty) => ty.clone(),
793 FnRetTy::Default(span) => {
794 panic!("Did not expect Default ret ty: {:?}", span);
795 }
796 };
797 let kind = TyKind::Tup(thin_vec![ty.clone(), ty.clone()]);
800 let ty = P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
801 d_decl.output = FnRetTy::Ty(ty);
802 }
803 if let DiffActivity::DualOnly = x.ret_activity {
804 }
808 }
809
810 d_decl.output =
812 if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() };
813
814 trace!("act_ret: {:?}", act_ret);
815
816 if act_ret.len() > 0 {
820 let ret_ty = match d_decl.output {
821 FnRetTy::Ty(ref ty) => {
822 if !active_only_ret {
823 act_ret.insert(0, ty.clone());
824 }
825 let kind = TyKind::Tup(act_ret);
826 P(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None })
827 }
828 FnRetTy::Default(span) => {
829 if act_ret.len() == 1 {
830 act_ret[0].clone()
831 } else {
832 let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect());
833 P(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None })
834 }
835 }
836 };
837 d_decl.output = FnRetTy::Ty(ret_ty);
838 }
839
840 let mut d_header = sig.header.clone();
841 if unsafe_activities {
842 d_header.safety = rustc_ast::Safety::Unsafe(span);
843 }
844 let d_sig = FnSig { header: d_header, decl: d_decl, span };
845 trace!("Generated signature: {:?}", d_sig);
846 (d_sig, new_inputs, idents, false)
847 }
848}
849
850pub(crate) use llvm_enzyme::expand;