rustc_monomorphize/partitioning/
autodiff.rs

1use rustc_ast::expand::autodiff_attrs::{AutoDiffItem, DiffActivity};
2use rustc_hir::def_id::LOCAL_CRATE;
3use rustc_middle::bug;
4use rustc_middle::mir::mono::MonoItem;
5use rustc_middle::ty::{self, Instance, Ty, TyCtxt};
6use rustc_symbol_mangling::symbol_name_for_instance_in_crate;
7use tracing::{debug, trace};
8
9use crate::partitioning::UsageMap;
10
11fn adjust_activity_to_abi<'tcx>(tcx: TyCtxt<'tcx>, fn_ty: Ty<'tcx>, da: &mut Vec<DiffActivity>) {
12    if !matches!(fn_ty.kind(), ty::FnDef(..)) {
13        bug!("expected fn def for autodiff, got {:?}", fn_ty);
14    }
15    let fnc_binder: ty::Binder<'_, ty::FnSig<'_>> = fn_ty.fn_sig(tcx);
16
17    // If rustc compiles the unmodified primal, we know that this copy of the function
18    // also has correct lifetimes. We know that Enzyme won't free the shadow too early
19    // (or actually at all), so let's strip lifetimes when computing the layout.
20    let x = tcx.instantiate_bound_regions_with_erased(fnc_binder);
21    let mut new_activities = vec![];
22    let mut new_positions = vec![];
23    for (i, ty) in x.inputs().iter().enumerate() {
24        if let Some(inner_ty) = ty.builtin_deref(true) {
25            if ty.is_fn_ptr() {
26                // FIXME(ZuseZ4): add a nicer error, or just figure out how to support them,
27                // since Enzyme itself can handle them.
28                tcx.dcx().err("function pointers are currently not supported in autodiff");
29            }
30            if inner_ty.is_slice() {
31                // We know that the length will be passed as extra arg.
32                if !da.is_empty() {
33                    // We are looking at a slice. The length of that slice will become an
34                    // extra integer on llvm level. Integers are always const.
35                    // However, if the slice get's duplicated, we want to know to later check the
36                    // size. So we mark the new size argument as FakeActivitySize.
37                    let activity = match da[i] {
38                        DiffActivity::DualOnly
39                        | DiffActivity::Dual
40                        | DiffActivity::DuplicatedOnly
41                        | DiffActivity::Duplicated => DiffActivity::FakeActivitySize,
42                        DiffActivity::Const => DiffActivity::Const,
43                        _ => bug!("unexpected activity for ptr/ref"),
44                    };
45                    new_activities.push(activity);
46                    new_positions.push(i + 1);
47                }
48                continue;
49            }
50        }
51    }
52    // now add the extra activities coming from slices
53    // Reverse order to not invalidate the indices
54    for _ in 0..new_activities.len() {
55        let pos = new_positions.pop().unwrap();
56        let activity = new_activities.pop().unwrap();
57        da.insert(pos, activity);
58    }
59}
60
61pub(crate) fn find_autodiff_source_functions<'tcx>(
62    tcx: TyCtxt<'tcx>,
63    usage_map: &UsageMap<'tcx>,
64    autodiff_mono_items: Vec<(&MonoItem<'tcx>, &Instance<'tcx>)>,
65) -> Vec<AutoDiffItem> {
66    let mut autodiff_items: Vec<AutoDiffItem> = vec![];
67    for (item, instance) in autodiff_mono_items {
68        let target_id = instance.def_id();
69        let cg_fn_attr = tcx.codegen_fn_attrs(target_id).autodiff_item.clone();
70        let Some(target_attrs) = cg_fn_attr else {
71            continue;
72        };
73        let mut input_activities: Vec<DiffActivity> = target_attrs.input_activity.clone();
74        if target_attrs.is_source() {
75            trace!("source found: {:?}", target_id);
76        }
77        if !target_attrs.apply_autodiff() {
78            continue;
79        }
80
81        let target_symbol = symbol_name_for_instance_in_crate(tcx, instance.clone(), LOCAL_CRATE);
82
83        let source =
84            usage_map.used_map.get(&item).unwrap().into_iter().find_map(|item| match *item {
85                MonoItem::Fn(ref instance_s) => {
86                    let source_id = instance_s.def_id();
87                    if let Some(ad) = &tcx.codegen_fn_attrs(source_id).autodiff_item
88                        && ad.is_active()
89                    {
90                        return Some(instance_s);
91                    }
92                    None
93                }
94                _ => None,
95            });
96        let inst = match source {
97            Some(source) => source,
98            None => continue,
99        };
100
101        debug!("source_id: {:?}", inst.def_id());
102        let fn_ty = inst.ty(tcx, ty::TypingEnv::fully_monomorphized());
103        assert!(fn_ty.is_fn());
104        adjust_activity_to_abi(tcx, fn_ty, &mut input_activities);
105        let symb = symbol_name_for_instance_in_crate(tcx, inst.clone(), LOCAL_CRATE);
106
107        let mut new_target_attrs = target_attrs.clone();
108        new_target_attrs.input_activity = input_activities;
109        let itm = new_target_attrs.into_item(symb, target_symbol);
110        autodiff_items.push(itm);
111    }
112
113    if !autodiff_items.is_empty() {
114        trace!("AUTODIFF ITEMS EXIST");
115        for item in &mut *autodiff_items {
116            trace!("{}", &item);
117        }
118    }
119
120    autodiff_items
121}