// rustc_codegen_llvm/builder/autodiff.rs

1use std::ptr;
2
3use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
4use rustc_ast::expand::typetree::FncTree;
5use rustc_codegen_ssa::common::TypeKind;
6use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
7use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
8use rustc_middle::{bug, ty};
9use rustc_target::callconv::PassMode;
10use tracing::debug;
11
12use crate::builder::{Builder, PlaceRef, UNNAMED};
13use crate::context::SimpleCx;
14use crate::declare::declare_simple_fn;
15use crate::llvm::{self, TRUE, Type, Value};
16
17pub(crate) fn adjust_activity_to_abi<'tcx>(
18    tcx: TyCtxt<'tcx>,
19    instance: Instance<'tcx>,
20    typing_env: TypingEnv<'tcx>,
21    da: &mut Vec<DiffActivity>,
22) {
23    let fn_ty = instance.ty(tcx, typing_env);
24
25    if !matches!(fn_ty.kind(), ty::FnDef(..)) {
26        bug!("expected fn def for autodiff, got {:?}", fn_ty);
27    }
28
29    // We don't actually pass the types back into the type system.
30    // All we do is decide how to handle the arguments.
31    let sig = fn_ty.fn_sig(tcx).skip_binder();
32
33    // FIXME(Sa4dUs): pass proper varargs once we have support for differentiating variadic functions
34    let Ok(fn_abi) =
35        tcx.fn_abi_of_instance(typing_env.as_query_input((instance, ty::List::empty())))
36    else {
37        bug!("failed to get fn_abi of instance with empty varargs");
38    };
39
40    let mut new_activities = vec![];
41    let mut new_positions = vec![];
42    let mut del_activities = 0;
43    for (i, ty) in sig.inputs().iter().enumerate() {
44        if let Some(inner_ty) = ty.builtin_deref(true) {
45            if inner_ty.is_slice() {
46                // Now we need to figure out the size of each slice element in memory to allow
47                // safety checks and usability improvements in the backend.
48                let sty = match inner_ty.builtin_index() {
49                    Some(sty) => sty,
50                    None => {
51                        panic!("slice element type unknown");
52                    }
53                };
54                let pci = PseudoCanonicalInput {
55                    typing_env: TypingEnv::fully_monomorphized(),
56                    value: sty,
57                };
58
59                let layout = tcx.layout_of(pci);
60                let elem_size = match layout {
61                    Ok(layout) => layout.size,
62                    Err(_) => {
63                        bug!("autodiff failed to compute slice element size");
64                    }
65                };
66                let elem_size: u32 = elem_size.bytes() as u32;
67
68                // We know that the length will be passed as extra arg.
69                if !da.is_empty() {
70                    // We are looking at a slice. The length of that slice will become an
71                    // extra integer on llvm level. Integers are always const.
72                    // However, if the slice get's duplicated, we want to know to later check the
73                    // size. So we mark the new size argument as FakeActivitySize.
74                    // There is one FakeActivitySize per slice, so for convenience we store the
75                    // slice element size in bytes in it. We will use the size in the backend.
76                    let activity = match da[i] {
77                        DiffActivity::DualOnly
78                        | DiffActivity::Dual
79                        | DiffActivity::Dualv
80                        | DiffActivity::DuplicatedOnly
81                        | DiffActivity::Duplicated => {
82                            DiffActivity::FakeActivitySize(Some(elem_size))
83                        }
84                        DiffActivity::Const => DiffActivity::Const,
85                        _ => bug!("unexpected activity for ptr/ref"),
86                    };
87                    new_activities.push(activity);
88                    new_positions.push(i + 1);
89                }
90
91                continue;
92            }
93        }
94
95        let pci = PseudoCanonicalInput { typing_env: TypingEnv::fully_monomorphized(), value: *ty };
96
97        let layout = match tcx.layout_of(pci) {
98            Ok(layout) => layout.layout,
99            Err(_) => {
100                bug!("failed to compute layout for type {:?}", ty);
101            }
102        };
103
104        let pass_mode = &fn_abi.args[i].mode;
105
106        // For ZST, just ignore and don't add its activity, as this arg won't be present
107        // in the LLVM passed to Enzyme.
108        // Some targets pass ZST indirectly in the C ABI, in that case, handle it as a normal arg
109        // FIXME(Sa4dUs): Enforce ZST corresponding diff activity be `Const`
110        if *pass_mode == PassMode::Ignore {
111            del_activities += 1;
112            da.remove(i);
113        }
114
115        // If the argument is lowered as a `ScalarPair`, we need to duplicate its activity.
116        // Otherwise, the number of activities won't match the number of LLVM arguments and
117        // this will lead to errors when verifying the Enzyme call.
118        if let rustc_abi::BackendRepr::ScalarPair(_, _) = layout.backend_repr() {
119            new_activities.push(da[i].clone());
120            new_positions.push(i + 1 - del_activities);
121        }
122    }
123    // now add the extra activities coming from slices
124    // Reverse order to not invalidate the indices
125    for _ in 0..new_activities.len() {
126        let pos = new_positions.pop().unwrap();
127        let activity = new_activities.pop().unwrap();
128        da.insert(pos, activity);
129    }
130}
131
132// When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
133// original inputs, as well as metadata and the additional shadow arguments.
134// This function matches the arguments from the outer function to the inner enzyme call.
135//
136// This function also considers that Rust level arguments not always match the llvm-ir level
137// arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
138// llvm-ir level. The number of activities matches the number of Rust level arguments, so we
139// need to match those.
140// FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
141// using iterators and peek()?
142fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
143    cx: &SimpleCx<'ll>,
144    builder: &mut Builder<'_, 'll, 'tcx>,
145    width: u32,
146    args: &mut Vec<&'ll Value>,
147    inputs: &[DiffActivity],
148    outer_args: &[&'ll Value],
149) {
150    debug!("matching autodiff arguments");
151    // We now handle the issue that Rust level arguments not always match the llvm-ir level
152    // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
153    // llvm-ir level. The number of activities matches the number of Rust level arguments, so we
154    // need to match those.
155    // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
156    // using iterators and peek()?
157    let mut outer_pos: usize = 0;
158    let mut activity_pos = 0;
159
160    // We used to use llvm's metadata to instruct enzyme how to differentiate a function.
161    // In debug mode we would use incremental compilation which caused the metadata to be
162    // dropped. This is prevented by now using named globals, which are also understood
163    // by Enzyme.
164    let global_const = cx.declare_global("enzyme_const", cx.type_ptr());
165    let global_out = cx.declare_global("enzyme_out", cx.type_ptr());
166    let global_dup = cx.declare_global("enzyme_dup", cx.type_ptr());
167    let global_dupv = cx.declare_global("enzyme_dupv", cx.type_ptr());
168    let global_dupnoneed = cx.declare_global("enzyme_dupnoneed", cx.type_ptr());
169    let global_dupnoneedv = cx.declare_global("enzyme_dupnoneedv", cx.type_ptr());
170
171    while activity_pos < inputs.len() {
172        let diff_activity = inputs[activity_pos as usize];
173        // Duplicated arguments received a shadow argument, into which enzyme will write the
174        // gradient.
175        let (activity, duplicated): (&Value, bool) = match diff_activity {
176            DiffActivity::None => panic!("not a valid input activity"),
177            DiffActivity::Const => (global_const, false),
178            DiffActivity::Active => (global_out, false),
179            DiffActivity::ActiveOnly => (global_out, false),
180            DiffActivity::Dual => (global_dup, true),
181            DiffActivity::Dualv => (global_dupv, true),
182            DiffActivity::DualOnly => (global_dupnoneed, true),
183            DiffActivity::DualvOnly => (global_dupnoneedv, true),
184            DiffActivity::Duplicated => (global_dup, true),
185            DiffActivity::DuplicatedOnly => (global_dupnoneed, true),
186            DiffActivity::FakeActivitySize(_) => (global_const, false),
187        };
188        let outer_arg = outer_args[outer_pos];
189        args.push(activity);
190        if matches!(diff_activity, DiffActivity::Dualv) {
191            let next_outer_arg = outer_args[outer_pos + 1];
192            let elem_bytes_size: u64 = match inputs[activity_pos + 1] {
193                DiffActivity::FakeActivitySize(Some(s)) => s.into(),
194                _ => bug!("incorrect Dualv handling recognized."),
195            };
196            // stride: sizeof(T) * n_elems.
197            // n_elems is the next integer.
198            // Now we multiply `4 * next_outer_arg` to get the stride.
199            let mul = unsafe {
200                llvm::LLVMBuildMul(
201                    builder.llbuilder,
202                    cx.get_const_int(cx.type_i64(), elem_bytes_size),
203                    next_outer_arg,
204                    UNNAMED,
205                )
206            };
207            args.push(mul);
208        }
209        args.push(outer_arg);
210        if duplicated {
211            // We know that duplicated args by construction have a following argument,
212            // so this can not be out of bounds.
213            let next_outer_arg = outer_args[outer_pos + 1];
214            let next_outer_ty = cx.val_ty(next_outer_arg);
215            // FIXME(ZuseZ4): We should add support for Vec here too, but it's less urgent since
216            // vectors behind references (&Vec<T>) are already supported. Users can not pass a
217            // Vec by value for reverse mode, so this would only help forward mode autodiff.
218            let slice = {
219                if activity_pos + 1 >= inputs.len() {
220                    // If there is no arg following our ptr, it also can't be a slice,
221                    // since that would lead to a ptr, int pair.
222                    false
223                } else {
224                    let next_activity = inputs[activity_pos + 1];
225                    // We analyze the MIR types and add this dummy activity if we visit a slice.
226                    matches!(next_activity, DiffActivity::FakeActivitySize(_))
227                }
228            };
229            if slice {
230                // A duplicated slice will have the following two outer_fn arguments:
231                // (..., ptr1, int1, ptr2, int2, ...). We add the following llvm-ir to our __enzyme call:
232                // (..., metadata! enzyme_dup, ptr, ptr, int1, ...).
233                // FIXME(ZuseZ4): We will upstream a safety check later which asserts that
234                // int2 >= int1, which means the shadow vector is large enough to store the gradient.
235                assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Integer);
236
237                let iterations =
238                    if matches!(diff_activity, DiffActivity::Dualv) { 1 } else { width as usize };
239
240                for i in 0..iterations {
241                    let next_outer_arg2 = outer_args[outer_pos + 2 * (i + 1)];
242                    let next_outer_ty2 = cx.val_ty(next_outer_arg2);
243                    assert_eq!(cx.type_kind(next_outer_ty2), TypeKind::Pointer);
244                    let next_outer_arg3 = outer_args[outer_pos + 2 * (i + 1) + 1];
245                    let next_outer_ty3 = cx.val_ty(next_outer_arg3);
246                    assert_eq!(cx.type_kind(next_outer_ty3), TypeKind::Integer);
247                    args.push(next_outer_arg2);
248                }
249                args.push(global_const);
250                args.push(next_outer_arg);
251                outer_pos += 2 + 2 * iterations;
252                activity_pos += 2;
253            } else {
254                // A duplicated pointer will have the following two outer_fn arguments:
255                // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
256                // (..., metadata! enzyme_dup, ptr, ptr, ...).
257                if matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly)
258                {
259                    assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Pointer);
260                }
261                // In the case of Dual we don't have assumptions, e.g. f32 would be valid.
262                args.push(next_outer_arg);
263                outer_pos += 2;
264                activity_pos += 1;
265
266                // Now, if width > 1, we need to account for that
267                for _ in 1..width {
268                    let next_outer_arg = outer_args[outer_pos];
269                    args.push(next_outer_arg);
270                    outer_pos += 1;
271                }
272            }
273        } else {
274            // We do not differentiate with resprect to this argument.
275            // We already added the metadata and argument above, so just increase the counters.
276            outer_pos += 1;
277            activity_pos += 1;
278        }
279    }
280}
281
282/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
283/// function with expected naming and calling conventions[^1] which will be
284/// discovered by the enzyme LLVM pass and its body populated with the differentiated
285/// `fn_to_diff`. `outer_fn` is then modified to have a call to the generated
286/// function and handle the differences between the Rust calling convention and
287/// Enzyme.
288/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
289// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
290// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
291pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
292    builder: &mut Builder<'_, 'll, 'tcx>,
293    cx: &SimpleCx<'ll>,
294    fn_to_diff: &'ll Value,
295    outer_name: &str,
296    ret_ty: &'ll Type,
297    fn_args: &[&'ll Value],
298    attrs: AutoDiffAttrs,
299    dest: PlaceRef<'tcx, &'ll Value>,
300    fnc_tree: FncTree,
301) {
302    // We have to pick the name depending on whether we want forward or reverse mode autodiff.
303    let mut ad_name: String = match attrs.mode {
304        DiffMode::Forward => "__enzyme_fwddiff",
305        DiffMode::Reverse => "__enzyme_autodiff",
306        _ => panic!("logic bug in autodiff, unrecognized mode"),
307    }
308    .to_string();
309
310    // add outer_name to ad_name to make it unique, in case users apply autodiff to multiple
311    // functions. Unwrap will only panic, if LLVM gave us an invalid string.
312    ad_name.push_str(outer_name);
313
314    // Let us assume the user wrote the following function square:
315    //
316    // ```llvm
317    // define double @square(double %x) {
318    // entry:
319    //  %0 = fmul double %x, %x
320    //  ret double %0
321    // }
322    //
323    // define double @dsquare(double %x) {
324    //  return 0.0;
325    // }
326    // ```
327    //
328    // so our `outer_fn` will be `dsquare`. The unsafe code section below now removes the placeholder
329    // code and inserts an autodiff call. We also add a declaration for the __enzyme_autodiff call.
330    // Again, the arguments to all functions are slightly simplified.
331    // ```llvm
332    // declare double @__enzyme_autodiff_square(...)
333    //
334    // define double @dsquare(double %x) {
335    // entry:
336    //   %0 = tail call double (...) @__enzyme_autodiff_square(double (double)* nonnull @square, double %x)
337    //   ret double %0
338    // }
339    // ```
340    let enzyme_ty = unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, TRUE) };
341
342    // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
343    // think a bit more about what should go here.
344    let cc = unsafe { llvm::LLVMGetFunctionCallConv(fn_to_diff) };
345    let ad_fn = declare_simple_fn(
346        cx,
347        &ad_name,
348        llvm::CallConv::try_from(cc).expect("invalid callconv"),
349        llvm::UnnamedAddr::No,
350        llvm::Visibility::Default,
351        enzyme_ty,
352    );
353
354    let num_args = llvm::LLVMCountParams(&fn_to_diff);
355    let mut args = Vec::with_capacity(num_args as usize + 1);
356    args.push(fn_to_diff);
357
358    let global_primal_ret = cx.declare_global("enzyme_primal_return", cx.type_ptr());
359    if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
360        args.push(global_primal_ret);
361    }
362    if attrs.width > 1 {
363        let global_width = cx.declare_global("enzyme_width", cx.type_ptr());
364        args.push(global_width);
365        args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
366    }
367
368    match_args_from_caller_to_enzyme(
369        &cx,
370        builder,
371        attrs.width,
372        &mut args,
373        &attrs.input_activity,
374        fn_args,
375    );
376
377    if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() {
378        crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree);
379    }
380
381    let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
382
383    let fn_ret_ty = builder.cx.val_ty(call);
384    if fn_ret_ty != builder.cx.type_void() && fn_ret_ty != builder.cx.type_struct(&[], false) {
385        // If we return void or an empty struct, then our caller (due to how we generated it)
386        // does not expect a return value. As such, we have no pointer (or place) into which
387        // we could store our value, and would store into an undef, which would cause UB.
388        // As such, we just ignore the return value in those cases.
389        builder.store_to_place(call, dest.val);
390    }
391}