1mod llvm_enzyme {
7 use std::str::FromStr;
8 use std::string::String;
9
10 use rustc_ast::expand::autodiff_attrs::{
11 DiffActivity, DiffMode, valid_input_activity, valid_ret_activity, valid_ty_for_activity,
12 };
13 use rustc_ast::token::{Lit, LitKind, Token, TokenKind};
14 use rustc_ast::tokenstream::*;
15 use rustc_ast::visit::AssocCtxt::*;
16 use rustc_ast::{
17 self as ast, AngleBracketedArg, AngleBracketedArgs, AnonConst, AssocItemKind, BindingMode,
18 FnRetTy, FnSig, GenericArg, GenericArgs, GenericParamKind, Generics, ItemKind,
19 MetaItemInner, MgcaDisambiguation, PatKind, Path, PathSegment, TyKind, Visibility,
20 };
21 use rustc_expand::base::{Annotatable, ExtCtxt};
22 use rustc_hir::attrs::RustcAutodiff;
23 use rustc_span::{Ident, Span, Symbol, kw, sym};
24 use thin_vec::{ThinVec, thin_vec};
25 use tracing::{debug, trace};
26
27 use crate::errors;
28
29 pub(crate) fn outer_normal_attr(
30 kind: &Box<rustc_ast::NormalAttr>,
31 id: rustc_ast::AttrId,
32 span: Span,
33 ) -> rustc_ast::Attribute {
34 let style = rustc_ast::AttrStyle::Outer;
35 let kind = rustc_ast::AttrKind::Normal(kind.clone());
36 rustc_ast::Attribute { kind, id, style, span }
37 }
38
39 fn has_ret(ty: &FnRetTy) -> bool {
42 match ty {
43 FnRetTy::Ty(ty) => !ty.kind.is_unit(),
44 FnRetTy::Default(_) => false,
45 }
46 }
47 fn first_ident(x: &MetaItemInner) -> rustc_span::Ident {
48 if let Some(l) = x.lit() {
49 match l.kind {
50 ast::LitKind::Int(val, _) => {
51 return rustc_span::Ident::from_str(val.get().to_string().as_str());
53 }
54 _ => {}
55 }
56 }
57
58 let segments = &x.meta_item().unwrap().path.segments;
59 if !(segments.len() == 1) {
::core::panicking::panic("assertion failed: segments.len() == 1")
};assert!(segments.len() == 1);
60 segments[0].ident
61 }
62
63 fn name(x: &MetaItemInner) -> String {
64 first_ident(x).name.to_string()
65 }
66
67 fn width(x: &MetaItemInner) -> Option<u128> {
68 let lit = x.lit()?;
69 match lit.kind {
70 ast::LitKind::Int(x, _) => Some(x.get()),
71 _ => return None,
72 }
73 }
74
75 fn extract_item_info(iitem: &Box<ast::Item>) -> Option<(Visibility, FnSig, Ident, Generics)> {
77 match &iitem.kind {
78 ItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => {
79 Some((iitem.vis.clone(), sig.clone(), ident.clone(), generics.clone()))
80 }
81 _ => None,
82 }
83 }
84
85 pub(crate) fn from_ast(
86 ecx: &mut ExtCtxt<'_>,
87 meta_item: &ThinVec<MetaItemInner>,
88 has_ret: bool,
89 mode: DiffMode,
90 ) -> RustcAutodiff {
91 let dcx = ecx.sess.dcx();
92
93 let mut first_activity = 1;
96
97 let width = if let [_, x, ..] = &meta_item[..]
98 && let Some(x) = width(x)
99 {
100 first_activity = 2;
101 match x.try_into() {
102 Ok(x) => x,
103 Err(_) => {
104 dcx.emit_err(errors::AutoDiffInvalidWidth {
105 span: meta_item[1].span(),
106 width: x,
107 });
108 return RustcAutodiff::error();
109 }
110 }
111 } else {
112 1
113 };
114
115 let mut activities: Vec<DiffActivity> = ::alloc::vec::Vec::new()vec![];
116 let mut errors = false;
117 for x in &meta_item[first_activity..] {
118 let activity_str = name(&x);
119 let res = DiffActivity::from_str(&activity_str);
120 match res {
121 Ok(x) => activities.push(x),
122 Err(_) => {
123 dcx.emit_err(errors::AutoDiffUnknownActivity {
124 span: x.span(),
125 act: activity_str,
126 });
127 errors = true;
128 }
129 };
130 }
131 if errors {
132 return RustcAutodiff::error();
133 }
134
135 let (ret_activity, input_activity) = if has_ret {
138 let Some((last, rest)) = activities.split_last() else {
139 {
::core::panicking::panic_fmt(format_args!("internal error: entered unreachable code: {0}",
format_args!("should not be reachable because we counted the number of activities previously")));
};unreachable!(
140 "should not be reachable because we counted the number of activities previously"
141 );
142 };
143 (last, rest)
144 } else {
145 (&DiffActivity::None, activities.as_slice())
146 };
147
148 RustcAutodiff {
149 mode,
150 width,
151 ret_activity: *ret_activity,
152 input_activity: input_activity.iter().cloned().collect(),
153 }
154 }
155
156 fn meta_item_inner_to_ts(t: &MetaItemInner, ts: &mut Vec<TokenTree>) {
157 let comma: Token = Token::new(TokenKind::Comma, Span::default());
158 let val = first_ident(t);
159 let t = Token::from_ast_ident(val);
160 ts.push(TokenTree::Token(t, Spacing::Joint));
161 ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
162 }
163
164 pub(crate) fn expand_forward(
165 ecx: &mut ExtCtxt<'_>,
166 expand_span: Span,
167 meta_item: &ast::MetaItem,
168 item: Annotatable,
169 ) -> Vec<Annotatable> {
170 expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Forward)
171 }
172
173 pub(crate) fn expand_reverse(
174 ecx: &mut ExtCtxt<'_>,
175 expand_span: Span,
176 meta_item: &ast::MetaItem,
177 item: Annotatable,
178 ) -> Vec<Annotatable> {
179 expand_with_mode(ecx, expand_span, meta_item, item, DiffMode::Reverse)
180 }
181
182 pub(crate) fn expand_with_mode(
206 ecx: &mut ExtCtxt<'_>,
207 expand_span: Span,
208 meta_item: &ast::MetaItem,
209 mut item: Annotatable,
210 mode: DiffMode,
211 ) -> Vec<Annotatable> {
212 let dcx = ecx.sess.dcx();
213
214 let Some((vis, sig, primal, generics, is_impl)) = (match &item {
218 Annotatable::Item(iitem) => {
219 extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false))
220 }
221 Annotatable::Stmt(stmt) => match &stmt.kind {
222 ast::StmtKind::Item(iitem) => {
223 extract_item_info(iitem).map(|(v, s, p, g)| (v, s, p, g, false))
224 }
225 _ => None,
226 },
227 Annotatable::AssocItem(assoc_item, _ctxt @ (Impl { of_trait: _ } | Trait)) => {
228 match &assoc_item.kind {
229 ast::AssocItemKind::Fn(box ast::Fn { sig, ident, generics, .. }) => Some((
230 assoc_item.vis.clone(),
231 sig.clone(),
232 ident.clone(),
233 generics.clone(),
234 true,
235 )),
236 _ => None,
237 }
238 }
239 _ => None,
240 }) else {
241 dcx.emit_err(errors::AutoDiffInvalidApplication { span: item.span() });
242 return ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
[item]))vec![item];
243 };
244
245 let meta_item_vec: ThinVec<MetaItemInner> = match meta_item.kind {
246 ast::MetaItemKind::List(ref vec) => vec.clone(),
247 _ => {
248 dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
249 return ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
[item]))vec![item];
250 }
251 };
252
253 let has_ret = has_ret(&sig.decl.output);
254
255 let mut ts: Vec<TokenTree> = ::alloc::vec::Vec::new()vec![];
258 if meta_item_vec.len() < 1 {
259 dcx.emit_err(errors::AutoDiffMissingConfig { span: item.span() });
261 return ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
[item]))vec![item];
262 }
263
264 let mode_symbol = match mode {
265 DiffMode::Forward => sym::Forward,
266 DiffMode::Reverse => sym::Reverse,
267 _ => {
::core::panicking::panic_fmt(format_args!("internal error: entered unreachable code: {0}",
format_args!("Unsupported mode: {0:?}", mode)));
}unreachable!("Unsupported mode: {:?}", mode),
268 };
269
270 let mode_token = Token::new(TokenKind::Ident(mode_symbol, false.into()), Span::default());
272 ts.insert(0, TokenTree::Token(mode_token, Spacing::Joint));
273 ts.insert(
274 1,
275 TokenTree::Token(Token::new(TokenKind::Comma, Span::default()), Spacing::Alone),
276 );
277
278 let start_position;
281 let kind: LitKind = LitKind::Integer;
282 let symbol;
283 if meta_item_vec.len() >= 2
284 && let Some(width) = width(&meta_item_vec[1])
285 {
286 start_position = 2;
287 symbol = Symbol::intern(&width.to_string());
288 } else {
289 start_position = 1;
290 symbol = sym::integer(1);
291 }
292
293 let l: Lit = Lit { kind, symbol, suffix: None };
294 let t = Token::new(TokenKind::Literal(l), Span::default());
295 let comma = Token::new(TokenKind::Comma, Span::default());
296 ts.push(TokenTree::Token(t, Spacing::Joint));
297 ts.push(TokenTree::Token(comma.clone(), Spacing::Alone));
298
299 for t in meta_item_vec.clone()[start_position..].iter() {
300 meta_item_inner_to_ts(t, &mut ts);
301 }
302
303 if !has_ret {
304 let t = Token::new(TokenKind::Ident(sym::None, false.into()), Span::default());
307 ts.push(TokenTree::Token(t, Spacing::Joint));
308 ts.push(TokenTree::Token(comma, Spacing::Alone));
309 }
310 ts.pop();
312 let ts: TokenStream = TokenStream::from_iter(ts);
313
314 let x: RustcAutodiff = from_ast(ecx, &meta_item_vec, has_ret, mode);
315 if !x.is_active() {
316 return ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
[item]))vec![item];
319 }
320 let span = ecx.with_def_site_ctxt(expand_span);
321
322 let d_sig = gen_enzyme_decl(ecx, &sig, &x, span);
323
324 let d_body = ecx.block(
325 span,
326 {
let len = [()].len();
let mut vec = ::thin_vec::ThinVec::with_capacity(len);
vec.push(call_autodiff(ecx, primal, first_ident(&meta_item_vec[0]), span,
&sig, &d_sig, &generics, is_impl));
vec
}thin_vec![call_autodiff(
327 ecx,
328 primal,
329 first_ident(&meta_item_vec[0]),
330 span,
331 &sig,
332 &d_sig,
333 &generics,
334 is_impl,
335 )],
336 );
337
338 let d_fn = Box::new(ast::Fn {
340 defaultness: ast::Defaultness::Implicit,
341 sig: d_sig,
342 ident: first_ident(&meta_item_vec[0]),
343 generics,
344 contract: None,
345 body: Some(d_body),
346 define_opaque: None,
347 eii_impls: ThinVec::new(),
348 });
349 let mut rustc_ad_attr =
350 Box::new(ast::NormalAttr::from_ident(Ident::with_dummy_span(sym::rustc_autodiff)));
351
352 let ts2: Vec<TokenTree> = ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
[TokenTree::Token(Token::new(TokenKind::Ident(sym::never,
false.into()), span), Spacing::Joint)]))vec![TokenTree::Token(
353 Token::new(TokenKind::Ident(sym::never, false.into()), span),
354 Spacing::Joint,
355 )];
356 let never_arg = ast::DelimArgs {
357 dspan: DelimSpan::from_single(span),
358 delim: ast::token::Delimiter::Parenthesis,
359 tokens: TokenStream::from_iter(ts2),
360 };
361 let inline_item = ast::AttrItem {
362 unsafety: ast::Safety::Default,
363 path: ast::Path::from_ident(Ident::with_dummy_span(sym::inline)),
364 args: rustc_ast::ast::AttrItemKind::Unparsed(ast::AttrArgs::Delimited(never_arg)),
365 tokens: None,
366 };
367 let inline_never_attr = Box::new(ast::NormalAttr { item: inline_item, tokens: None });
368 let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
369 let attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
370 let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
371 let inline_never = outer_normal_attr(&inline_never_attr, new_id, span);
372
373 fn same_attribute(attr: &ast::AttrKind, item: &ast::AttrKind) -> bool {
375 match (attr, item) {
376 (ast::AttrKind::Normal(a), ast::AttrKind::Normal(b)) => {
377 let a = &a.item.path;
378 let b = &b.item.path;
379 a.segments.iter().eq_by(&b.segments, |a, b| a.ident == b.ident)
380 }
381 _ => false,
382 }
383 }
384
385 let mut has_inline_never = false;
386
387 let orig_annotatable: Annotatable = match item {
389 Annotatable::Item(ref mut iitem) => {
390 if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
391 iitem.attrs.push(attr);
392 }
393 if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
394 has_inline_never = true;
395 }
396 Annotatable::Item(iitem.clone())
397 }
398 Annotatable::AssocItem(ref mut assoc_item, ctxt @ (Impl { .. } | Trait)) => {
399 if !assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
400 assoc_item.attrs.push(attr);
401 }
402 if assoc_item.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
403 has_inline_never = true;
404 }
405 Annotatable::AssocItem(assoc_item.clone(), ctxt)
406 }
407 Annotatable::Stmt(ref mut stmt) => {
408 match stmt.kind {
409 ast::StmtKind::Item(ref mut iitem) => {
410 if !iitem.attrs.iter().any(|a| same_attribute(&a.kind, &attr.kind)) {
411 iitem.attrs.push(attr);
412 }
413 if iitem.attrs.iter().any(|a| same_attribute(&a.kind, &inline_never.kind)) {
414 has_inline_never = true;
415 }
416 }
417 _ => {
::core::panicking::panic_fmt(format_args!("internal error: entered unreachable code: {0}",
format_args!("stmt kind checked previously")));
}unreachable!("stmt kind checked previously"),
418 };
419
420 Annotatable::Stmt(stmt.clone())
421 }
422 _ => {
423 {
::core::panicking::panic_fmt(format_args!("internal error: entered unreachable code: {0}",
format_args!("annotatable kind checked previously")));
}unreachable!("annotatable kind checked previously")
424 }
425 };
426 rustc_ad_attr.item.args = rustc_ast::ast::AttrItemKind::Unparsed(
428 rustc_ast::AttrArgs::Delimited(rustc_ast::DelimArgs {
429 dspan: DelimSpan::dummy(),
430 delim: rustc_ast::token::Delimiter::Parenthesis,
431 tokens: ts,
432 }),
433 );
434
435 let new_id = ecx.sess.psess.attr_id_generator.mk_attr_id();
436 let d_attr = outer_normal_attr(&rustc_ad_attr, new_id, span);
437
438 let mut d_attrs = {
let len = [()].len();
let mut vec = ::thin_vec::ThinVec::with_capacity(len);
vec.push(d_attr);
vec
}thin_vec![d_attr];
440
441 if has_inline_never {
442 d_attrs.push(inline_never);
443 }
444
445 let d_annotatable = match &item {
446 Annotatable::AssocItem(_, ctxt) => {
447 let assoc_item: AssocItemKind = ast::AssocItemKind::Fn(d_fn);
448 let d_fn = Box::new(ast::AssocItem {
449 attrs: d_attrs,
450 id: ast::DUMMY_NODE_ID,
451 span,
452 vis,
453 kind: assoc_item,
454 tokens: None,
455 });
456 Annotatable::AssocItem(d_fn, *ctxt)
457 }
458 Annotatable::Item(_) => {
459 let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));
460 d_fn.vis = vis;
461
462 Annotatable::Item(d_fn)
463 }
464 Annotatable::Stmt(_) => {
465 let mut d_fn = ecx.item(span, d_attrs, ItemKind::Fn(d_fn));
466 d_fn.vis = vis;
467
468 Annotatable::Stmt(Box::new(ast::Stmt {
469 id: ast::DUMMY_NODE_ID,
470 kind: ast::StmtKind::Item(d_fn),
471 span,
472 }))
473 }
474 _ => {
475 {
::core::panicking::panic_fmt(format_args!("internal error: entered unreachable code: {0}",
format_args!("item kind checked previously")));
}unreachable!("item kind checked previously")
476 }
477 };
478
479 return ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
[orig_annotatable, d_annotatable]))vec![orig_annotatable, d_annotatable];
480 }
481
482 fn assure_mut_ref(ty: &ast::Ty) -> ast::Ty {
485 let mut ty = ty.clone();
486 match ty.kind {
487 TyKind::Ptr(ref mut mut_ty) => {
488 mut_ty.mutbl = ast::Mutability::Mut;
489 }
490 TyKind::Ref(_, ref mut mut_ty) => {
491 mut_ty.mutbl = ast::Mutability::Mut;
492 }
493 _ => {
494 {
::core::panicking::panic_fmt(format_args!("unsupported type: {0:?}", ty));
};panic!("unsupported type: {:?}", ty);
495 }
496 }
497 ty
498 }
499
500 fn call_autodiff(
505 ecx: &ExtCtxt<'_>,
506 primal: Ident,
507 diff: Ident,
508 span: Span,
509 p_sig: &FnSig,
510 d_sig: &FnSig,
511 generics: &Generics,
512 is_impl: bool,
513 ) -> rustc_ast::Stmt {
514 let primal_path_expr = gen_turbofish_expr(ecx, primal, generics, span, is_impl);
515
516 let self_ty = || ecx.ty_path(ast::Path::from_ident(Ident::with_dummy_span(kw::SelfUpper)));
517 let fn_ptr_params: ThinVec<ast::Param> = p_sig
518 .decl
519 .inputs
520 .iter()
521 .map(|param| {
522 let ty = match ¶m.ty.kind {
523 TyKind::ImplicitSelf => self_ty(),
524 TyKind::Ref(lt, mt) if #[allow(non_exhaustive_omitted_patterns)] match mt.ty.kind {
TyKind::ImplicitSelf => true,
_ => false,
}matches!(mt.ty.kind, TyKind::ImplicitSelf) => ecx.ty(
525 span,
526 TyKind::Ref(lt.clone(), ast::MutTy { ty: self_ty(), mutbl: mt.mutbl }),
527 ),
528 TyKind::Ptr(mt) if #[allow(non_exhaustive_omitted_patterns)] match mt.ty.kind {
TyKind::ImplicitSelf => true,
_ => false,
}matches!(mt.ty.kind, TyKind::ImplicitSelf) => {
529 ecx.ty(span, TyKind::Ptr(ast::MutTy { ty: self_ty(), mutbl: mt.mutbl }))
530 }
531 _ => param.ty.clone(),
532 };
533 ast::Param {
534 attrs: ast::AttrVec::new(),
535 ty,
536 pat: Box::new(ecx.pat_wild(span)),
537 id: ast::DUMMY_NODE_ID,
538 span,
539 is_placeholder: false,
540 }
541 })
542 .collect();
543 let fn_ptr_ty = ecx.ty(
544 span,
545 TyKind::FnPtr(Box::new(ast::FnPtrTy {
546 safety: p_sig.header.safety,
547 ext: p_sig.header.ext,
548 generic_params: ThinVec::new(),
549 decl: Box::new(ast::FnDecl {
550 inputs: fn_ptr_params,
551 output: p_sig.decl.output.clone(),
552 }),
553 decl_span: span,
554 })),
555 );
556 let primal_fn_ptr = ecx.expr(span, ast::ExprKind::Cast(primal_path_expr, fn_ptr_ty));
557
558 let diff_path_expr = gen_turbofish_expr(ecx, diff, generics, span, is_impl);
559
560 let tuple_expr = ecx.expr_tuple(
561 span,
562 d_sig
563 .decl
564 .inputs
565 .iter()
566 .map(|arg| match arg.pat.kind {
567 PatKind::Ident(_, ident, _) => ecx.expr_path(ecx.path_ident(span, ident)),
568 _ => ::core::panicking::panic("not yet implemented")todo!(),
569 })
570 .collect::<ThinVec<_>>()
571 .into(),
572 );
573
574 let enzyme_path_idents = ecx.std_path(&[sym::intrinsics, sym::autodiff]);
575 let enzyme_path = ecx.path(span, enzyme_path_idents);
576 let call_expr = ecx.expr_call(
577 span,
578 ecx.expr_path(enzyme_path),
579 ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
[primal_fn_ptr, diff_path_expr, tuple_expr]))vec![primal_fn_ptr, diff_path_expr, tuple_expr].into(),
580 );
581
582 ecx.stmt_expr(call_expr)
583 }
584
585 fn gen_turbofish_expr(
589 ecx: &ExtCtxt<'_>,
590 ident: Ident,
591 generics: &Generics,
592 span: Span,
593 is_impl: bool,
594 ) -> Box<ast::Expr> {
595 let generic_args = generics
596 .params
597 .iter()
598 .filter_map(|p| match &p.kind {
599 GenericParamKind::Type { .. } => {
600 let path = ast::Path::from_ident(p.ident);
601 let ty = ecx.ty_path(path);
602 Some(AngleBracketedArg::Arg(GenericArg::Type(ty)))
603 }
604 GenericParamKind::Const { .. } => {
605 let expr = ecx.expr_path(ast::Path::from_ident(p.ident));
606 let anon_const = AnonConst {
607 id: ast::DUMMY_NODE_ID,
608 value: expr,
609 mgca_disambiguation: MgcaDisambiguation::Direct,
610 };
611 Some(AngleBracketedArg::Arg(GenericArg::Const(anon_const)))
612 }
613 GenericParamKind::Lifetime { .. } => None,
614 })
615 .collect::<ThinVec<_>>();
616
617 let args: AngleBracketedArgs = AngleBracketedArgs { span, args: generic_args };
618
619 let segment = PathSegment {
620 ident,
621 id: ast::DUMMY_NODE_ID,
622 args: Some(Box::new(GenericArgs::AngleBracketed(args))),
623 };
624
625 let segments = if is_impl {
626 {
let len = [(), ()].len();
let mut vec = ::thin_vec::ThinVec::with_capacity(len);
vec.push(PathSegment {
ident: Ident::from_str("Self"),
id: ast::DUMMY_NODE_ID,
args: None,
});
vec.push(segment);
vec
}thin_vec![
627 PathSegment { ident: Ident::from_str("Self"), id: ast::DUMMY_NODE_ID, args: None },
628 segment,
629 ]
630 } else {
631 {
let len = [()].len();
let mut vec = ::thin_vec::ThinVec::with_capacity(len);
vec.push(segment);
vec
}thin_vec![segment]
632 };
633
634 let path = Path { span, segments, tokens: None };
635
636 ecx.expr_path(path)
637 }
638
639 fn gen_enzyme_decl(
651 ecx: &ExtCtxt<'_>,
652 sig: &ast::FnSig,
653 x: &RustcAutodiff,
654 span: Span,
655 ) -> ast::FnSig {
656 let dcx = ecx.sess.dcx();
657 let has_ret = has_ret(&sig.decl.output);
658 let sig_args = sig.decl.inputs.len() + if has_ret { 1 } else { 0 };
659 let num_activities = x.input_activity.len() + if x.has_ret_activity() { 1 } else { 0 };
660 if sig_args != num_activities {
661 dcx.emit_err(errors::AutoDiffInvalidNumberActivities {
662 span,
663 expected: sig_args,
664 found: num_activities,
665 });
666 return sig.clone();
668 }
669 if !(sig.decl.inputs.len() == x.input_activity.len()) {
::core::panicking::panic("assertion failed: sig.decl.inputs.len() == x.input_activity.len()")
};assert!(sig.decl.inputs.len() == x.input_activity.len());
670 if !(has_ret == x.has_ret_activity()) {
::core::panicking::panic("assertion failed: has_ret == x.has_ret_activity()")
};assert!(has_ret == x.has_ret_activity());
671 let mut d_decl = sig.decl.clone();
672 let mut d_inputs = Vec::new();
673 let mut new_inputs = Vec::new();
674 let mut idents = Vec::new();
675 let mut act_ret = ThinVec::new();
676
677 let mut errors = false;
680 for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
681 if !valid_input_activity(x.mode, *activity) {
682 dcx.emit_err(errors::AutoDiffInvalidApplicationModeAct {
683 span,
684 mode: x.mode.to_string(),
685 act: activity.to_string(),
686 });
687 errors = true;
688 }
689 if !valid_ty_for_activity(&arg.ty, *activity) {
690 dcx.emit_err(errors::AutoDiffInvalidTypeForActivity {
691 span: arg.ty.span,
692 act: activity.to_string(),
693 });
694 errors = true;
695 }
696 }
697
698 if has_ret && !valid_ret_activity(x.mode, x.ret_activity) {
699 dcx.emit_err(errors::AutoDiffInvalidRetAct {
700 span,
701 mode: x.mode.to_string(),
702 act: x.ret_activity.to_string(),
703 });
704 }
707
708 if errors {
709 return sig.clone();
711 }
712
713 let unsafe_activities = x
714 .input_activity
715 .iter()
716 .any(|&act| #[allow(non_exhaustive_omitted_patterns)] match act {
DiffActivity::DuplicatedOnly | DiffActivity::DualOnly => true,
_ => false,
}matches!(act, DiffActivity::DuplicatedOnly | DiffActivity::DualOnly));
717 for (arg, activity) in sig.decl.inputs.iter().zip(x.input_activity.iter()) {
718 d_inputs.push(arg.clone());
719 match activity {
720 DiffActivity::Active => {
721 act_ret.push(arg.ty.clone());
722 }
724 DiffActivity::ActiveOnly => {
725 }
728 DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => {
729 for i in 0..x.width {
730 let mut shadow_arg = arg.clone();
731 shadow_arg.ty = Box::new(assure_mut_ref(&arg.ty));
733 let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
734 ident.name
735 } else {
736 {
use ::tracing::__macro_support::Callsite as _;
static __CALLSITE: ::tracing::callsite::DefaultCallsite =
{
static META: ::tracing::Metadata<'static> =
{
::tracing_core::metadata::Metadata::new("event compiler/rustc_builtin_macros/src/autodiff.rs:736",
"rustc_builtin_macros::autodiff::llvm_enzyme",
::tracing::Level::DEBUG,
::tracing_core::__macro_support::Option::Some("compiler/rustc_builtin_macros/src/autodiff.rs"),
::tracing_core::__macro_support::Option::Some(736u32),
::tracing_core::__macro_support::Option::Some("rustc_builtin_macros::autodiff::llvm_enzyme"),
::tracing_core::field::FieldSet::new(&["message"],
::tracing_core::callsite::Identifier(&__CALLSITE)),
::tracing::metadata::Kind::EVENT)
};
::tracing::callsite::DefaultCallsite::new(&META)
};
let enabled =
::tracing::Level::DEBUG <= ::tracing::level_filters::STATIC_MAX_LEVEL
&&
::tracing::Level::DEBUG <=
::tracing::level_filters::LevelFilter::current() &&
{
let interest = __CALLSITE.interest();
!interest.is_never() &&
::tracing::__macro_support::__is_enabled(__CALLSITE.metadata(),
interest)
};
if enabled {
(|value_set: ::tracing::field::ValueSet|
{
let meta = __CALLSITE.metadata();
::tracing::Event::dispatch(meta, &value_set);
;
})({
#[allow(unused_imports)]
use ::tracing::field::{debug, display, Value};
let mut iter = __CALLSITE.metadata().fields().iter();
__CALLSITE.metadata().fields().value_set(&[(&::tracing::__macro_support::Iterator::next(&mut iter).expect("FieldSet corrupted (this is a bug)"),
::tracing::__macro_support::Option::Some(&format_args!("{0:#?}",
&shadow_arg.pat) as &dyn Value))])
});
} else { ; }
};debug!("{:#?}", &shadow_arg.pat);
737 { ::core::panicking::panic_fmt(format_args!("not an ident?")); };panic!("not an ident?");
738 };
739 let name: String = ::alloc::__export::must_use({
::alloc::fmt::format(format_args!("d{0}_{1}", old_name, i))
})format!("d{}_{}", old_name, i);
740 new_inputs.push(name.clone());
741 let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
742 shadow_arg.pat = Box::new(ast::Pat {
743 id: ast::DUMMY_NODE_ID,
744 kind: PatKind::Ident(BindingMode::NONE, ident, None),
745 span: shadow_arg.pat.span,
746 tokens: shadow_arg.pat.tokens.clone(),
747 });
748 d_inputs.push(shadow_arg.clone());
749 }
750 }
751 DiffActivity::Dual
752 | DiffActivity::DualOnly
753 | DiffActivity::Dualv
754 | DiffActivity::DualvOnly => {
755 let iterations =
758 if #[allow(non_exhaustive_omitted_patterns)] match activity {
DiffActivity::Dualv | DiffActivity::DualvOnly => true,
_ => false,
}matches!(activity, DiffActivity::Dualv | DiffActivity::DualvOnly) {
759 1
760 } else {
761 x.width
762 };
763 for i in 0..iterations {
764 let mut shadow_arg = arg.clone();
765 let old_name = if let PatKind::Ident(_, ident, _) = arg.pat.kind {
766 ident.name
767 } else {
768 {
use ::tracing::__macro_support::Callsite as _;
static __CALLSITE: ::tracing::callsite::DefaultCallsite =
{
static META: ::tracing::Metadata<'static> =
{
::tracing_core::metadata::Metadata::new("event compiler/rustc_builtin_macros/src/autodiff.rs:768",
"rustc_builtin_macros::autodiff::llvm_enzyme",
::tracing::Level::DEBUG,
::tracing_core::__macro_support::Option::Some("compiler/rustc_builtin_macros/src/autodiff.rs"),
::tracing_core::__macro_support::Option::Some(768u32),
::tracing_core::__macro_support::Option::Some("rustc_builtin_macros::autodiff::llvm_enzyme"),
::tracing_core::field::FieldSet::new(&["message"],
::tracing_core::callsite::Identifier(&__CALLSITE)),
::tracing::metadata::Kind::EVENT)
};
::tracing::callsite::DefaultCallsite::new(&META)
};
let enabled =
::tracing::Level::DEBUG <= ::tracing::level_filters::STATIC_MAX_LEVEL
&&
::tracing::Level::DEBUG <=
::tracing::level_filters::LevelFilter::current() &&
{
let interest = __CALLSITE.interest();
!interest.is_never() &&
::tracing::__macro_support::__is_enabled(__CALLSITE.metadata(),
interest)
};
if enabled {
(|value_set: ::tracing::field::ValueSet|
{
let meta = __CALLSITE.metadata();
::tracing::Event::dispatch(meta, &value_set);
;
})({
#[allow(unused_imports)]
use ::tracing::field::{debug, display, Value};
let mut iter = __CALLSITE.metadata().fields().iter();
__CALLSITE.metadata().fields().value_set(&[(&::tracing::__macro_support::Iterator::next(&mut iter).expect("FieldSet corrupted (this is a bug)"),
::tracing::__macro_support::Option::Some(&format_args!("{0:#?}",
&shadow_arg.pat) as &dyn Value))])
});
} else { ; }
};debug!("{:#?}", &shadow_arg.pat);
769 { ::core::panicking::panic_fmt(format_args!("not an ident?")); };panic!("not an ident?");
770 };
771 let name: String = ::alloc::__export::must_use({
::alloc::fmt::format(format_args!("b{0}_{1}", old_name, i))
})format!("b{}_{}", old_name, i);
772 new_inputs.push(name.clone());
773 let ident = Ident::from_str_and_span(&name, shadow_arg.pat.span);
774 shadow_arg.pat = Box::new(ast::Pat {
775 id: ast::DUMMY_NODE_ID,
776 kind: PatKind::Ident(BindingMode::NONE, ident, None),
777 span: shadow_arg.pat.span,
778 tokens: shadow_arg.pat.tokens.clone(),
779 });
780 d_inputs.push(shadow_arg.clone());
781 }
782 }
783 DiffActivity::Const => {
784 }
786 DiffActivity::None | DiffActivity::FakeActivitySize(_) => {
787 { ::core::panicking::panic_fmt(format_args!("Should not happen")); };panic!("Should not happen");
788 }
789 }
790 if let PatKind::Ident(_, ident, _) = arg.pat.kind {
791 idents.push(ident.clone());
792 } else {
793 { ::core::panicking::panic_fmt(format_args!("not an ident?")); };panic!("not an ident?");
794 }
795 }
796
797 let active_only_ret = x.ret_activity == DiffActivity::ActiveOnly;
798 if active_only_ret {
799 if !x.mode.is_rev() {
::core::panicking::panic("assertion failed: x.mode.is_rev()")
};assert!(x.mode.is_rev());
800 }
801
802 if x.mode.is_rev() {
805 match x.ret_activity {
806 DiffActivity::Active | DiffActivity::ActiveOnly => {
807 let ty = match d_decl.output {
808 FnRetTy::Ty(ref ty) => ty.clone(),
809 FnRetTy::Default(span) => {
810 {
::core::panicking::panic_fmt(format_args!("Did not expect Default ret ty: {0:?}",
span));
};panic!("Did not expect Default ret ty: {:?}", span);
811 }
812 };
813 let name = "dret".to_string();
814 let ident = Ident::from_str_and_span(&name, ty.span);
815 let shadow_arg = ast::Param {
816 attrs: ThinVec::new(),
817 ty: ty.clone(),
818 pat: Box::new(ast::Pat {
819 id: ast::DUMMY_NODE_ID,
820 kind: PatKind::Ident(BindingMode::NONE, ident, None),
821 span: ty.span,
822 tokens: None,
823 }),
824 id: ast::DUMMY_NODE_ID,
825 span: ty.span,
826 is_placeholder: false,
827 };
828 d_inputs.push(shadow_arg);
829 new_inputs.push(name);
830 }
831 _ => {}
832 }
833 }
834 d_decl.inputs = d_inputs.into();
835
836 if x.mode.is_fwd() {
837 let ty = match d_decl.output {
838 FnRetTy::Ty(ref ty) => ty.clone(),
839 FnRetTy::Default(span) => {
840 let kind = TyKind::Tup(ThinVec::new());
842 let ty = Box::new(rustc_ast::Ty {
843 kind,
844 id: ast::DUMMY_NODE_ID,
845 span,
846 tokens: None,
847 });
848 d_decl.output = FnRetTy::Ty(ty.clone());
849 if !#[allow(non_exhaustive_omitted_patterns)] match x.ret_activity {
DiffActivity::None => true,
_ => false,
} {
::core::panicking::panic("assertion failed: matches!(x.ret_activity, DiffActivity::None)")
};assert!(matches!(x.ret_activity, DiffActivity::None));
850 ty
852 }
853 };
854
855 if #[allow(non_exhaustive_omitted_patterns)] match x.ret_activity {
DiffActivity::Dual | DiffActivity::Dualv => true,
_ => false,
}matches!(x.ret_activity, DiffActivity::Dual | DiffActivity::Dualv) {
856 let kind = if x.width == 1 || #[allow(non_exhaustive_omitted_patterns)] match x.ret_activity {
DiffActivity::Dualv => true,
_ => false,
}matches!(x.ret_activity, DiffActivity::Dualv) {
857 TyKind::Tup({
let len = [(), ()].len();
let mut vec = ::thin_vec::ThinVec::with_capacity(len);
vec.push(ty.clone());
vec.push(ty.clone());
vec
}thin_vec![ty.clone(), ty.clone()])
860 } else {
861 let anon_const = rustc_ast::AnonConst {
863 id: ast::DUMMY_NODE_ID,
864 value: ecx.expr_usize(span, 1 + x.width as usize),
865 mgca_disambiguation: MgcaDisambiguation::Direct,
866 };
867 TyKind::Array(ty.clone(), anon_const)
868 };
869 let ty = Box::new(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
870 d_decl.output = FnRetTy::Ty(ty);
871 }
872 if #[allow(non_exhaustive_omitted_patterns)] match x.ret_activity {
DiffActivity::DualOnly | DiffActivity::DualvOnly => true,
_ => false,
}matches!(x.ret_activity, DiffActivity::DualOnly | DiffActivity::DualvOnly) {
873 if x.width > 1 {
877 let anon_const = rustc_ast::AnonConst {
878 id: ast::DUMMY_NODE_ID,
879 value: ecx.expr_usize(span, x.width as usize),
880 mgca_disambiguation: MgcaDisambiguation::Direct,
881 };
882 let kind = TyKind::Array(ty.clone(), anon_const);
883 let ty =
884 Box::new(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None });
885 d_decl.output = FnRetTy::Ty(ty);
886 }
887 }
888 }
889
890 d_decl.output =
892 if active_only_ret { FnRetTy::Default(span) } else { d_decl.output.clone() };
893
894 {
use ::tracing::__macro_support::Callsite as _;
static __CALLSITE: ::tracing::callsite::DefaultCallsite =
{
static META: ::tracing::Metadata<'static> =
{
::tracing_core::metadata::Metadata::new("event compiler/rustc_builtin_macros/src/autodiff.rs:894",
"rustc_builtin_macros::autodiff::llvm_enzyme",
::tracing::Level::TRACE,
::tracing_core::__macro_support::Option::Some("compiler/rustc_builtin_macros/src/autodiff.rs"),
::tracing_core::__macro_support::Option::Some(894u32),
::tracing_core::__macro_support::Option::Some("rustc_builtin_macros::autodiff::llvm_enzyme"),
::tracing_core::field::FieldSet::new(&["message"],
::tracing_core::callsite::Identifier(&__CALLSITE)),
::tracing::metadata::Kind::EVENT)
};
::tracing::callsite::DefaultCallsite::new(&META)
};
let enabled =
::tracing::Level::TRACE <= ::tracing::level_filters::STATIC_MAX_LEVEL
&&
::tracing::Level::TRACE <=
::tracing::level_filters::LevelFilter::current() &&
{
let interest = __CALLSITE.interest();
!interest.is_never() &&
::tracing::__macro_support::__is_enabled(__CALLSITE.metadata(),
interest)
};
if enabled {
(|value_set: ::tracing::field::ValueSet|
{
let meta = __CALLSITE.metadata();
::tracing::Event::dispatch(meta, &value_set);
;
})({
#[allow(unused_imports)]
use ::tracing::field::{debug, display, Value};
let mut iter = __CALLSITE.metadata().fields().iter();
__CALLSITE.metadata().fields().value_set(&[(&::tracing::__macro_support::Iterator::next(&mut iter).expect("FieldSet corrupted (this is a bug)"),
::tracing::__macro_support::Option::Some(&format_args!("act_ret: {0:?}",
act_ret) as &dyn Value))])
});
} else { ; }
};trace!("act_ret: {:?}", act_ret);
895
896 if act_ret.len() > 0 {
900 let ret_ty = match d_decl.output {
901 FnRetTy::Ty(ref ty) => {
902 if !active_only_ret {
903 act_ret.insert(0, ty.clone());
904 }
905 let kind = TyKind::Tup(act_ret);
906 Box::new(rustc_ast::Ty { kind, id: ty.id, span: ty.span, tokens: None })
907 }
908 FnRetTy::Default(span) => {
909 if act_ret.len() == 1 {
910 act_ret[0].clone()
911 } else {
912 let kind = TyKind::Tup(act_ret.iter().map(|arg| arg.clone()).collect());
913 Box::new(rustc_ast::Ty { kind, id: ast::DUMMY_NODE_ID, span, tokens: None })
914 }
915 }
916 };
917 d_decl.output = FnRetTy::Ty(ret_ty);
918 }
919
920 let mut d_header = sig.header.clone();
921 if unsafe_activities {
922 d_header.safety = rustc_ast::Safety::Unsafe(span);
923 }
924 let d_sig = FnSig { header: d_header, decl: d_decl, span };
925 {
use ::tracing::__macro_support::Callsite as _;
static __CALLSITE: ::tracing::callsite::DefaultCallsite =
{
static META: ::tracing::Metadata<'static> =
{
::tracing_core::metadata::Metadata::new("event compiler/rustc_builtin_macros/src/autodiff.rs:925",
"rustc_builtin_macros::autodiff::llvm_enzyme",
::tracing::Level::TRACE,
::tracing_core::__macro_support::Option::Some("compiler/rustc_builtin_macros/src/autodiff.rs"),
::tracing_core::__macro_support::Option::Some(925u32),
::tracing_core::__macro_support::Option::Some("rustc_builtin_macros::autodiff::llvm_enzyme"),
::tracing_core::field::FieldSet::new(&["message"],
::tracing_core::callsite::Identifier(&__CALLSITE)),
::tracing::metadata::Kind::EVENT)
};
::tracing::callsite::DefaultCallsite::new(&META)
};
let enabled =
::tracing::Level::TRACE <= ::tracing::level_filters::STATIC_MAX_LEVEL
&&
::tracing::Level::TRACE <=
::tracing::level_filters::LevelFilter::current() &&
{
let interest = __CALLSITE.interest();
!interest.is_never() &&
::tracing::__macro_support::__is_enabled(__CALLSITE.metadata(),
interest)
};
if enabled {
(|value_set: ::tracing::field::ValueSet|
{
let meta = __CALLSITE.metadata();
::tracing::Event::dispatch(meta, &value_set);
;
})({
#[allow(unused_imports)]
use ::tracing::field::{debug, display, Value};
let mut iter = __CALLSITE.metadata().fields().iter();
__CALLSITE.metadata().fields().value_set(&[(&::tracing::__macro_support::Iterator::next(&mut iter).expect("FieldSet corrupted (this is a bug)"),
::tracing::__macro_support::Option::Some(&format_args!("Generated signature: {0:?}",
d_sig) as &dyn Value))])
});
} else { ; }
};trace!("Generated signature: {:?}", d_sig);
926 d_sig
927 }
928}
929
930pub(crate) use llvm_enzyme::{expand_forward, expand_reverse};