// rustc_codegen_llvm/builder/autodiff.rs

1use std::ptr;
2
3use rustc_ast::expand::autodiff_attrs::{DiffActivity, DiffMode};
4use rustc_ast::expand::typetree::FncTree;
5use rustc_codegen_ssa::common::TypeKind;
6use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
7use rustc_data_structures::thin_vec::ThinVec;
8use rustc_hir::attrs::RustcAutodiff;
9use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
10use rustc_middle::{bug, ty};
11use rustc_target::callconv::PassMode;
12use tracing::debug;
13
14use crate::builder::{Builder, PlaceRef, UNNAMED};
15use crate::context::SimpleCx;
16use crate::declare::declare_simple_fn;
17use crate::llvm::{self, TRUE, Type, Value};
18
19pub(crate) fn adjust_activity_to_abi<'tcx>(
20    tcx: TyCtxt<'tcx>,
21    instance: Instance<'tcx>,
22    typing_env: TypingEnv<'tcx>,
23    da: &mut ThinVec<DiffActivity>,
24) {
25    let fn_ty = instance.ty(tcx, typing_env);
26
27    if !#[allow(non_exhaustive_omitted_patterns)] match fn_ty.kind() {
    ty::FnDef(..) => true,
    _ => false,
}matches!(fn_ty.kind(), ty::FnDef(..)) {
28        ::rustc_middle::util::bug::bug_fmt(format_args!("expected fn def for autodiff, got {0:?}",
        fn_ty));bug!("expected fn def for autodiff, got {:?}", fn_ty);
29    }
30
31    // We don't actually pass the types back into the type system.
32    // All we do is decide how to handle the arguments.
33    let sig = fn_ty.fn_sig(tcx).skip_binder();
34
35    // FIXME(Sa4dUs): pass proper varargs once we have support for differentiating variadic functions
36    let Ok(fn_abi) =
37        tcx.fn_abi_of_instance(typing_env.as_query_input((instance, ty::List::empty())))
38    else {
39        ::rustc_middle::util::bug::bug_fmt(format_args!("failed to get fn_abi of instance with empty varargs"));bug!("failed to get fn_abi of instance with empty varargs");
40    };
41
42    let mut new_activities = ::alloc::vec::Vec::new()vec![];
43    let mut new_positions = ::alloc::vec::Vec::new()vec![];
44    let mut del_activities = 0;
45    for (i, ty) in sig.inputs().iter().enumerate() {
46        if let Some(inner_ty) = ty.builtin_deref(true) {
47            if inner_ty.is_slice() {
48                // Now we need to figure out the size of each slice element in memory to allow
49                // safety checks and usability improvements in the backend.
50                let sty = match inner_ty.builtin_index() {
51                    Some(sty) => sty,
52                    None => {
53                        { ::core::panicking::panic_fmt(format_args!("slice element type unknown")); };panic!("slice element type unknown");
54                    }
55                };
56                let pci = PseudoCanonicalInput {
57                    typing_env: TypingEnv::fully_monomorphized(),
58                    value: sty,
59                };
60
61                let layout = tcx.layout_of(pci);
62                let elem_size = match layout {
63                    Ok(layout) => layout.size,
64                    Err(_) => {
65                        ::rustc_middle::util::bug::bug_fmt(format_args!("autodiff failed to compute slice element size"));bug!("autodiff failed to compute slice element size");
66                    }
67                };
68                let elem_size: u32 = elem_size.bytes() as u32;
69
70                // We know that the length will be passed as extra arg.
71                if !da.is_empty() {
72                    // We are looking at a slice. The length of that slice will become an
73                    // extra integer on llvm level. Integers are always const.
74                    // However, if the slice get's duplicated, we want to know to later check the
75                    // size. So we mark the new size argument as FakeActivitySize.
76                    // There is one FakeActivitySize per slice, so for convenience we store the
77                    // slice element size in bytes in it. We will use the size in the backend.
78                    let activity = match da[i] {
79                        DiffActivity::DualOnly
80                        | DiffActivity::Dual
81                        | DiffActivity::Dualv
82                        | DiffActivity::DuplicatedOnly
83                        | DiffActivity::Duplicated => {
84                            DiffActivity::FakeActivitySize(Some(elem_size))
85                        }
86                        DiffActivity::Const => DiffActivity::Const,
87                        _ => ::rustc_middle::util::bug::bug_fmt(format_args!("unexpected activity for ptr/ref"))bug!("unexpected activity for ptr/ref"),
88                    };
89                    new_activities.push(activity);
90                    new_positions.push(i + 1);
91                }
92
93                continue;
94            }
95        }
96
97        let pci = PseudoCanonicalInput { typing_env: TypingEnv::fully_monomorphized(), value: *ty };
98
99        let layout = match tcx.layout_of(pci) {
100            Ok(layout) => layout.layout,
101            Err(_) => {
102                ::rustc_middle::util::bug::bug_fmt(format_args!("failed to compute layout for type {0:?}",
        ty));bug!("failed to compute layout for type {:?}", ty);
103            }
104        };
105
106        let pass_mode = &fn_abi.args[i].mode;
107
108        // For ZST, just ignore and don't add its activity, as this arg won't be present
109        // in the LLVM passed to Enzyme.
110        // Some targets pass ZST indirectly in the C ABI, in that case, handle it as a normal arg
111        // FIXME(Sa4dUs): Enforce ZST corresponding diff activity be `Const`
112        if *pass_mode == PassMode::Ignore {
113            del_activities += 1;
114            da.remove(i);
115        }
116
117        // If the argument is lowered as a `ScalarPair`, we need to duplicate its activity.
118        // Otherwise, the number of activities won't match the number of LLVM arguments and
119        // this will lead to errors when verifying the Enzyme call.
120        if let rustc_abi::BackendRepr::ScalarPair(_, _) = layout.backend_repr() {
121            new_activities.push(da[i].clone());
122            new_positions.push(i + 1 - del_activities);
123        }
124    }
125    // now add the extra activities coming from slices
126    // Reverse order to not invalidate the indices
127    for _ in 0..new_activities.len() {
128        let pos = new_positions.pop().unwrap();
129        let activity = new_activities.pop().unwrap();
130        da.insert(pos, activity);
131    }
132}
133
134// When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
135// original inputs, as well as metadata and the additional shadow arguments.
136// This function matches the arguments from the outer function to the inner enzyme call.
137//
138// This function also considers that Rust level arguments not always match the llvm-ir level
139// arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
140// llvm-ir level. The number of activities matches the number of Rust level arguments, so we
141// need to match those.
142// FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
143// using iterators and peek()?
144fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
145    cx: &SimpleCx<'ll>,
146    builder: &mut Builder<'_, 'll, 'tcx>,
147    width: u32,
148    args: &mut Vec<&'ll Value>,
149    inputs: &[DiffActivity],
150    outer_args: &[&'ll Value],
151) {
152    {
    use ::tracing::__macro_support::Callsite as _;
    static __CALLSITE: ::tracing::callsite::DefaultCallsite =
        {
            static META: ::tracing::Metadata<'static> =
                {
                    ::tracing_core::metadata::Metadata::new("event compiler/rustc_codegen_llvm/src/builder/autodiff.rs:152",
                        "rustc_codegen_llvm::builder::autodiff",
                        ::tracing::Level::DEBUG,
                        ::tracing_core::__macro_support::Option::Some("compiler/rustc_codegen_llvm/src/builder/autodiff.rs"),
                        ::tracing_core::__macro_support::Option::Some(152u32),
                        ::tracing_core::__macro_support::Option::Some("rustc_codegen_llvm::builder::autodiff"),
                        ::tracing_core::field::FieldSet::new(&["message"],
                            ::tracing_core::callsite::Identifier(&__CALLSITE)),
                        ::tracing::metadata::Kind::EVENT)
                };
            ::tracing::callsite::DefaultCallsite::new(&META)
        };
    let enabled =
        ::tracing::Level::DEBUG <= ::tracing::level_filters::STATIC_MAX_LEVEL
                &&
                ::tracing::Level::DEBUG <=
                    ::tracing::level_filters::LevelFilter::current() &&
            {
                let interest = __CALLSITE.interest();
                !interest.is_never() &&
                    ::tracing::__macro_support::__is_enabled(__CALLSITE.metadata(),
                        interest)
            };
    if enabled {
        (|value_set: ::tracing::field::ValueSet|
                    {
                        let meta = __CALLSITE.metadata();
                        ::tracing::Event::dispatch(meta, &value_set);
                        ;
                    })({
                #[allow(unused_imports)]
                use ::tracing::field::{debug, display, Value};
                let mut iter = __CALLSITE.metadata().fields().iter();
                __CALLSITE.metadata().fields().value_set(&[(&::tracing::__macro_support::Iterator::next(&mut iter).expect("FieldSet corrupted (this is a bug)"),
                                    ::tracing::__macro_support::Option::Some(&format_args!("matching autodiff arguments")
                                            as &dyn Value))])
            });
    } else { ; }
};debug!("matching autodiff arguments");
153    // We now handle the issue that Rust level arguments not always match the llvm-ir level
154    // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
155    // llvm-ir level. The number of activities matches the number of Rust level arguments, so we
156    // need to match those.
157    // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
158    // using iterators and peek()?
159    let mut outer_pos: usize = 0;
160    let mut activity_pos = 0;
161
162    // We used to use llvm's metadata to instruct enzyme how to differentiate a function.
163    // In debug mode we would use incremental compilation which caused the metadata to be
164    // dropped. This is prevented by now using named globals, which are also understood
165    // by Enzyme.
166    let global_const = cx.declare_global("enzyme_const", cx.type_ptr());
167    let global_out = cx.declare_global("enzyme_out", cx.type_ptr());
168    let global_dup = cx.declare_global("enzyme_dup", cx.type_ptr());
169    let global_dupv = cx.declare_global("enzyme_dupv", cx.type_ptr());
170    let global_dupnoneed = cx.declare_global("enzyme_dupnoneed", cx.type_ptr());
171    let global_dupnoneedv = cx.declare_global("enzyme_dupnoneedv", cx.type_ptr());
172
173    while activity_pos < inputs.len() {
174        let diff_activity = inputs[activity_pos as usize];
175        // Duplicated arguments received a shadow argument, into which enzyme will write the
176        // gradient.
177        let (activity, duplicated): (&Value, bool) = match diff_activity {
178            DiffActivity::None => { ::core::panicking::panic_fmt(format_args!("not a valid input activity")); }panic!("not a valid input activity"),
179            DiffActivity::Const => (global_const, false),
180            DiffActivity::Active => (global_out, false),
181            DiffActivity::ActiveOnly => (global_out, false),
182            DiffActivity::Dual => (global_dup, true),
183            DiffActivity::Dualv => (global_dupv, true),
184            DiffActivity::DualOnly => (global_dupnoneed, true),
185            DiffActivity::DualvOnly => (global_dupnoneedv, true),
186            DiffActivity::Duplicated => (global_dup, true),
187            DiffActivity::DuplicatedOnly => (global_dupnoneed, true),
188            DiffActivity::FakeActivitySize(_) => (global_const, false),
189        };
190        let outer_arg = outer_args[outer_pos];
191        args.push(activity);
192        if #[allow(non_exhaustive_omitted_patterns)] match diff_activity {
    DiffActivity::Dualv => true,
    _ => false,
}matches!(diff_activity, DiffActivity::Dualv) {
193            let next_outer_arg = outer_args[outer_pos + 1];
194            let elem_bytes_size: u64 = match inputs[activity_pos + 1] {
195                DiffActivity::FakeActivitySize(Some(s)) => s.into(),
196                _ => ::rustc_middle::util::bug::bug_fmt(format_args!("incorrect Dualv handling recognized."))bug!("incorrect Dualv handling recognized."),
197            };
198            // stride: sizeof(T) * n_elems.
199            // n_elems is the next integer.
200            // Now we multiply `4 * next_outer_arg` to get the stride.
201            let mul = unsafe {
202                llvm::LLVMBuildMul(
203                    builder.llbuilder,
204                    cx.get_const_int(cx.type_i64(), elem_bytes_size),
205                    next_outer_arg,
206                    UNNAMED,
207                )
208            };
209            args.push(mul);
210        }
211        args.push(outer_arg);
212        if duplicated {
213            // We know that duplicated args by construction have a following argument,
214            // so this can not be out of bounds.
215            let next_outer_arg = outer_args[outer_pos + 1];
216            let next_outer_ty = cx.val_ty(next_outer_arg);
217            // FIXME(ZuseZ4): We should add support for Vec here too, but it's less urgent since
218            // vectors behind references (&Vec<T>) are already supported. Users can not pass a
219            // Vec by value for reverse mode, so this would only help forward mode autodiff.
220            let slice = {
221                if activity_pos + 1 >= inputs.len() {
222                    // If there is no arg following our ptr, it also can't be a slice,
223                    // since that would lead to a ptr, int pair.
224                    false
225                } else {
226                    let next_activity = inputs[activity_pos + 1];
227                    // We analyze the MIR types and add this dummy activity if we visit a slice.
228                    #[allow(non_exhaustive_omitted_patterns)] match next_activity {
    DiffActivity::FakeActivitySize(_) => true,
    _ => false,
}matches!(next_activity, DiffActivity::FakeActivitySize(_))
229                }
230            };
231            if slice {
232                // A duplicated slice will have the following two outer_fn arguments:
233                // (..., ptr1, int1, ptr2, int2, ...). We add the following llvm-ir to our __enzyme call:
234                // (..., metadata! enzyme_dup, ptr, ptr, int1, ...).
235                // FIXME(ZuseZ4): We will upstream a safety check later which asserts that
236                // int2 >= int1, which means the shadow vector is large enough to store the gradient.
237                match (&cx.type_kind(next_outer_ty), &TypeKind::Integer) {
    (left_val, right_val) => {
        if !(*left_val == *right_val) {
            let kind = ::core::panicking::AssertKind::Eq;
            ::core::panicking::assert_failed(kind, &*left_val, &*right_val,
                ::core::option::Option::None);
        }
    }
};assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Integer);
238
239                let iterations =
240                    if #[allow(non_exhaustive_omitted_patterns)] match diff_activity {
    DiffActivity::Dualv => true,
    _ => false,
}matches!(diff_activity, DiffActivity::Dualv) { 1 } else { width as usize };
241
242                for i in 0..iterations {
243                    let next_outer_arg2 = outer_args[outer_pos + 2 * (i + 1)];
244                    let next_outer_ty2 = cx.val_ty(next_outer_arg2);
245                    match (&cx.type_kind(next_outer_ty2), &TypeKind::Pointer) {
    (left_val, right_val) => {
        if !(*left_val == *right_val) {
            let kind = ::core::panicking::AssertKind::Eq;
            ::core::panicking::assert_failed(kind, &*left_val, &*right_val,
                ::core::option::Option::None);
        }
    }
};assert_eq!(cx.type_kind(next_outer_ty2), TypeKind::Pointer);
246                    let next_outer_arg3 = outer_args[outer_pos + 2 * (i + 1) + 1];
247                    let next_outer_ty3 = cx.val_ty(next_outer_arg3);
248                    match (&cx.type_kind(next_outer_ty3), &TypeKind::Integer) {
    (left_val, right_val) => {
        if !(*left_val == *right_val) {
            let kind = ::core::panicking::AssertKind::Eq;
            ::core::panicking::assert_failed(kind, &*left_val, &*right_val,
                ::core::option::Option::None);
        }
    }
};assert_eq!(cx.type_kind(next_outer_ty3), TypeKind::Integer);
249                    args.push(next_outer_arg2);
250                }
251                args.push(global_const);
252                args.push(next_outer_arg);
253                outer_pos += 2 + 2 * iterations;
254                activity_pos += 2;
255            } else {
256                // A duplicated pointer will have the following two outer_fn arguments:
257                // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
258                // (..., metadata! enzyme_dup, ptr, ptr, ...).
259                if #[allow(non_exhaustive_omitted_patterns)] match diff_activity {
    DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => true,
    _ => false,
}matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly)
260                {
261                    match (&cx.type_kind(next_outer_ty), &TypeKind::Pointer) {
    (left_val, right_val) => {
        if !(*left_val == *right_val) {
            let kind = ::core::panicking::AssertKind::Eq;
            ::core::panicking::assert_failed(kind, &*left_val, &*right_val,
                ::core::option::Option::None);
        }
    }
};assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Pointer);
262                }
263                // In the case of Dual we don't have assumptions, e.g. f32 would be valid.
264                args.push(next_outer_arg);
265                outer_pos += 2;
266                activity_pos += 1;
267
268                // Now, if width > 1, we need to account for that
269                for _ in 1..width {
270                    let next_outer_arg = outer_args[outer_pos];
271                    args.push(next_outer_arg);
272                    outer_pos += 1;
273                }
274            }
275        } else {
276            // We do not differentiate with resprect to this argument.
277            // We already added the metadata and argument above, so just increase the counters.
278            outer_pos += 1;
279            activity_pos += 1;
280        }
281    }
282}
283
284/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
285/// function with expected naming and calling conventions[^1] which will be
286/// discovered by the enzyme LLVM pass and its body populated with the differentiated
287/// `fn_to_diff`. `outer_fn` is then modified to have a call to the generated
288/// function and handle the differences between the Rust calling convention and
289/// Enzyme.
290/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
291// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
292// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
293pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
294    builder: &mut Builder<'_, 'll, 'tcx>,
295    cx: &SimpleCx<'ll>,
296    fn_to_diff: &'ll Value,
297    outer_name: &str,
298    ret_ty: &'ll Type,
299    fn_args: &[&'ll Value],
300    attrs: &RustcAutodiff,
301    dest: PlaceRef<'tcx, &'ll Value>,
302    fnc_tree: FncTree,
303) {
304    // We have to pick the name depending on whether we want forward or reverse mode autodiff.
305    let mut ad_name: String = match attrs.mode {
306        DiffMode::Forward => "__enzyme_fwddiff",
307        DiffMode::Reverse => "__enzyme_autodiff",
308        _ => {
    ::core::panicking::panic_fmt(format_args!("logic bug in autodiff, unrecognized mode"));
}panic!("logic bug in autodiff, unrecognized mode"),
309    }
310    .to_string();
311
312    // add outer_name to ad_name to make it unique, in case users apply autodiff to multiple
313    // functions. Unwrap will only panic, if LLVM gave us an invalid string.
314    ad_name.push_str(outer_name);
315
316    // Let us assume the user wrote the following function square:
317    //
318    // ```llvm
319    // define double @square(double %x) {
320    // entry:
321    //  %0 = fmul double %x, %x
322    //  ret double %0
323    // }
324    //
325    // define double @dsquare(double %x) {
326    //  return 0.0;
327    // }
328    // ```
329    //
330    // so our `outer_fn` will be `dsquare`. The unsafe code section below now removes the placeholder
331    // code and inserts an autodiff call. We also add a declaration for the __enzyme_autodiff call.
332    // Again, the arguments to all functions are slightly simplified.
333    // ```llvm
334    // declare double @__enzyme_autodiff_square(...)
335    //
336    // define double @dsquare(double %x) {
337    // entry:
338    //   %0 = tail call double (...) @__enzyme_autodiff_square(double (double)* nonnull @square, double %x)
339    //   ret double %0
340    // }
341    // ```
342    let enzyme_ty = unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, TRUE) };
343
344    // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
345    // think a bit more about what should go here.
346    let cc = unsafe { llvm::LLVMGetFunctionCallConv(fn_to_diff) };
347    let ad_fn = declare_simple_fn(
348        cx,
349        &ad_name,
350        llvm::CallConv::try_from(cc).expect("invalid callconv"),
351        llvm::UnnamedAddr::No,
352        llvm::Visibility::Default,
353        enzyme_ty,
354    );
355
356    let num_args = llvm::LLVMCountParams(&fn_to_diff);
357    let mut args = Vec::with_capacity(num_args as usize + 1);
358    args.push(fn_to_diff);
359
360    let global_primal_ret = cx.declare_global("enzyme_primal_return", cx.type_ptr());
361    if #[allow(non_exhaustive_omitted_patterns)] match attrs.ret_activity {
    DiffActivity::Dual | DiffActivity::Active => true,
    _ => false,
}matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
362        args.push(global_primal_ret);
363    }
364    if attrs.width > 1 {
365        let global_width = cx.declare_global("enzyme_width", cx.type_ptr());
366        args.push(global_width);
367        args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
368    }
369
370    match_args_from_caller_to_enzyme(
371        &cx,
372        builder,
373        attrs.width,
374        &mut args,
375        &attrs.input_activity,
376        fn_args,
377    );
378
379    if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() {
380        crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree);
381    }
382
383    let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
384
385    let fn_ret_ty = builder.cx.val_ty(call);
386    if fn_ret_ty != builder.cx.type_void() && fn_ret_ty != builder.cx.type_struct(&[], false) {
387        // If we return void or an empty struct, then our caller (due to how we generated it)
388        // does not expect a return value. As such, we have no pointer (or place) into which
389        // we could store our value, and would store into an undef, which would cause UB.
390        // As such, we just ignore the return value in those cases.
391        builder.store_to_place(call, dest.val);
392    }
393}