rustc_codegen_llvm/builder/autodiff.rs
use std::ptr;

use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
use rustc_ast::expand::typetree::FncTree;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
use rustc_middle::{bug, ty};
use rustc_target::callconv::PassMode;
use tracing::debug;

use crate::builder::{Builder, PlaceRef, UNNAMED};
use crate::context::SimpleCx;
use crate::declare::declare_simple_fn;
use crate::llvm::{self, TRUE, Type, Value};

pub(crate) fn adjust_activity_to_abi<'tcx>(
    tcx: TyCtxt<'tcx>,
    instance: Instance<'tcx>,
    typing_env: TypingEnv<'tcx>,
    da: &mut Vec<DiffActivity>,
) {
    let fn_ty = instance.ty(tcx, typing_env);

    if !matches!(fn_ty.kind(), ty::FnDef(..)) {
        bug!("expected fn def for autodiff, got {:?}", fn_ty);
    }

    // We don't actually pass the types back into the type system.
    // All we do is decide how to handle the arguments.
    let sig = fn_ty.fn_sig(tcx).skip_binder();

    // FIXME(Sa4dUs): pass proper varargs once we have support for differentiating variadic functions
    let Ok(fn_abi) =
        tcx.fn_abi_of_instance(typing_env.as_query_input((instance, ty::List::empty())))
    else {
        bug!("failed to get fn_abi of instance with empty varargs");
    };

    let mut new_activities = vec![];
    let mut new_positions = vec![];
    let mut del_activities = 0;
    for (i, ty) in sig.inputs().iter().enumerate() {
        if let Some(inner_ty) = ty.builtin_deref(true) {
            if inner_ty.is_slice() {
                // Now we need to figure out the size of each slice element in memory to allow
                // safety checks and usability improvements in the backend.
                let sty = match inner_ty.builtin_index() {
                    Some(sty) => sty,
                    None => {
                        panic!("slice element type unknown");
                    }
                };
                let pci = PseudoCanonicalInput {
                    typing_env: TypingEnv::fully_monomorphized(),
                    value: sty,
                };

                let layout = tcx.layout_of(pci);
                let elem_size = match layout {
                    Ok(layout) => layout.size,
                    Err(_) => {
                        bug!("autodiff failed to compute slice element size");
                    }
                };
                let elem_size: u32 = elem_size.bytes() as u32;

                // We know that the length will be passed as an extra arg.
                if !da.is_empty() {
                    // We are looking at a slice. The length of that slice will become an
                    // extra integer on the LLVM level. Integers are always const.
                    // However, if the slice gets duplicated, we want to know to later check the
                    // size. So we mark the new size argument as FakeActivitySize.
                    // There is one FakeActivitySize per slice, so for convenience we store the
                    // slice element size in bytes in it. We will use the size in the backend.
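                    // For example (illustrative): for `fn f(x: &[f32])` with `x` marked
                    // `Duplicated`, the slice is lowered to a (ptr, len) pair, and the len
                    // position gets `FakeActivitySize(Some(4))`, since an f32 element is 4 bytes.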
                    let activity = match da[i] {
                        DiffActivity::DualOnly
                        | DiffActivity::Dual
                        | DiffActivity::Dualv
                        | DiffActivity::DuplicatedOnly
                        | DiffActivity::Duplicated => {
                            DiffActivity::FakeActivitySize(Some(elem_size))
                        }
                        DiffActivity::Const => DiffActivity::Const,
                        _ => bug!("unexpected activity for ptr/ref"),
                    };
                    new_activities.push(activity);
                    new_positions.push(i + 1);
                }

                continue;
            }
        }

        let pci = PseudoCanonicalInput { typing_env: TypingEnv::fully_monomorphized(), value: *ty };

        let layout = match tcx.layout_of(pci) {
            Ok(layout) => layout.layout,
            Err(_) => {
                bug!("failed to compute layout for type {:?}", ty);
            }
        };

        let pass_mode = &fn_abi.args[i].mode;

        // For a ZST, just ignore it and don't add its activity, as this arg won't be present
        // in the LLVM IR passed to Enzyme.
        // Some targets pass ZSTs indirectly in the C ABI; in that case, handle it as a normal arg.
        // FIXME(Sa4dUs): Enforce that the diff activity corresponding to a ZST is `Const`.
        if *pass_mode == PassMode::Ignore {
            del_activities += 1;
            da.remove(i);
        }

        // If the argument is lowered as a `ScalarPair`, we need to duplicate its activity.
        // Otherwise, the number of activities won't match the number of LLVM arguments and
        // this will lead to errors when verifying the Enzyme call.
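        // For example (illustrative): an argument of type `(f32, f32)` is typically lowered to
        // two scalar LLVM arguments, so its single Rust-level activity has to be emitted twice
        // to keep the activity list aligned with the LLVM argument list.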
        if let rustc_abi::BackendRepr::ScalarPair(_, _) = layout.backend_repr() {
            new_activities.push(da[i].clone());
            new_positions.push(i + 1 - del_activities);
        }
    }
    // now add the extra activities coming from slices
    // Reverse order to not invalidate the indices
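    // E.g. if `new_positions` is [1, 3], we pop and insert at 3 first and then at 1, so the
    // earlier index is still valid when we use it.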
    for _ in 0..new_activities.len() {
        let pos = new_positions.pop().unwrap();
        let activity = new_activities.pop().unwrap();
        da.insert(pos, activity);
    }
}

// When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
// original inputs, as well as metadata and the additional shadow arguments.
// This function matches the arguments from the outer function to the inner enzyme call.
//
// This function also considers that Rust level arguments do not always match the llvm-ir level
// arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on the
// llvm-ir level. The number of activities matches the number of Rust level arguments, so we
// need to match those.
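// For example, `fn f(x: &[f32], y: f32)` has two arguments at the Rust level, but the outer
// llvm-ir function receives three (ptr, integer length, float), which is why the outer argument
// position and the activity position are tracked separately below.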
// FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
// using iterators and peek()?
fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
    cx: &SimpleCx<'ll>,
    builder: &mut Builder<'_, 'll, 'tcx>,
    width: u32,
    args: &mut Vec<&'ll Value>,
    inputs: &[DiffActivity],
    outer_args: &[&'ll Value],
) {
    debug!("matching autodiff arguments");
    // We now handle the issue that Rust level arguments do not always match the llvm-ir level
    // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
    // the llvm-ir level. The number of activities matches the number of Rust level arguments,
    // so we need to match those.
    // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
    // using iterators and peek()?
    let mut outer_pos: usize = 0;
    let mut activity_pos = 0;

    // We used to use llvm's metadata to instruct enzyme how to differentiate a function.
    // In debug mode we would use incremental compilation which caused the metadata to be
    // dropped. This is prevented by now using named globals, which are also understood
    // by Enzyme.
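    // Enzyme recognizes these markers by their well-known `enzyme_*` names, so a plain global
    // declaration (roughly `@enzyme_const = external global ptr` on the llvm-ir level) is
    // enough; no metadata node is required.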
    let global_const = cx.declare_global("enzyme_const", cx.type_ptr());
    let global_out = cx.declare_global("enzyme_out", cx.type_ptr());
    let global_dup = cx.declare_global("enzyme_dup", cx.type_ptr());
    let global_dupv = cx.declare_global("enzyme_dupv", cx.type_ptr());
    let global_dupnoneed = cx.declare_global("enzyme_dupnoneed", cx.type_ptr());
    let global_dupnoneedv = cx.declare_global("enzyme_dupnoneedv", cx.type_ptr());

    while activity_pos < inputs.len() {
        let diff_activity = inputs[activity_pos as usize];
        // Duplicated arguments receive a shadow argument, into which enzyme will write the
        // gradient.
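        // E.g. a `Duplicated` `&mut f64` contributes two outer arguments (primal ptr, shadow
        // ptr), and we emit (enzyme_dup, primal ptr, shadow ptr) into the __enzyme call below.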
        let (activity, duplicated): (&Value, bool) = match diff_activity {
            DiffActivity::None => panic!("not a valid input activity"),
            DiffActivity::Const => (global_const, false),
            DiffActivity::Active => (global_out, false),
            DiffActivity::ActiveOnly => (global_out, false),
            DiffActivity::Dual => (global_dup, true),
            DiffActivity::Dualv => (global_dupv, true),
            DiffActivity::DualOnly => (global_dupnoneed, true),
            DiffActivity::DualvOnly => (global_dupnoneedv, true),
            DiffActivity::Duplicated => (global_dup, true),
            DiffActivity::DuplicatedOnly => (global_dupnoneed, true),
            DiffActivity::FakeActivitySize(_) => (global_const, false),
        };
        let outer_arg = outer_args[outer_pos];
        args.push(activity);
        if matches!(diff_activity, DiffActivity::Dualv) {
            let next_outer_arg = outer_args[outer_pos + 1];
            let elem_bytes_size: u64 = match inputs[activity_pos + 1] {
                DiffActivity::FakeActivitySize(Some(s)) => s.into(),
                _ => bug!("incorrect Dualv handling recognized."),
            };
            // stride: sizeof(T) * n_elems.
            // n_elems is the next integer.
            // Now we multiply `elem_bytes_size * next_outer_arg` to get the stride.
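            // E.g. for a `&[f32]` input with 10 elements, elem_bytes_size is 4 and the next
            // integer argument is 10, so the computed stride is 40 bytes.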
            let mul = unsafe {
                llvm::LLVMBuildMul(
                    builder.llbuilder,
                    cx.get_const_int(cx.type_i64(), elem_bytes_size),
                    next_outer_arg,
                    UNNAMED,
                )
            };
            args.push(mul);
        }
        args.push(outer_arg);
        if duplicated {
            // We know that duplicated args by construction have a following argument,
            // so this cannot be out of bounds.
            let next_outer_arg = outer_args[outer_pos + 1];
            let next_outer_ty = cx.val_ty(next_outer_arg);
            // FIXME(ZuseZ4): We should add support for Vec here too, but it's less urgent since
            // vectors behind references (&Vec<T>) are already supported. Users cannot pass a
            // Vec by value for reverse mode, so this would only help forward mode autodiff.
            let slice = {
                if activity_pos + 1 >= inputs.len() {
                    // If there is no arg following our ptr, it also can't be a slice,
                    // since a slice would be lowered to a ptr, int pair.
                    false
                } else {
                    let next_activity = inputs[activity_pos + 1];
                    // We analyze the MIR types and add this dummy activity if we visit a slice.
                    matches!(next_activity, DiffActivity::FakeActivitySize(_))
                }
            };
            if slice {
                // A duplicated slice will have the following four outer_fn arguments:
                // (..., ptr1, int1, ptr2, int2, ...). We add the following llvm-ir to our __enzyme call:
                // (..., metadata! enzyme_dup, ptr, ptr, int1, ...).
                // FIXME(ZuseZ4): We will upstream a safety check later which asserts that
                // int2 >= int1, which means the shadow vector is large enough to store the gradient.
                assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Integer);

                let iterations =
                    if matches!(diff_activity, DiffActivity::Dualv) { 1 } else { width as usize };

                for i in 0..iterations {
                    let next_outer_arg2 = outer_args[outer_pos + 2 * (i + 1)];
                    let next_outer_ty2 = cx.val_ty(next_outer_arg2);
                    assert_eq!(cx.type_kind(next_outer_ty2), TypeKind::Pointer);
                    let next_outer_arg3 = outer_args[outer_pos + 2 * (i + 1) + 1];
                    let next_outer_ty3 = cx.val_ty(next_outer_arg3);
                    assert_eq!(cx.type_kind(next_outer_ty3), TypeKind::Integer);
                    args.push(next_outer_arg2);
                }
                args.push(global_const);
                args.push(next_outer_arg);
                outer_pos += 2 + 2 * iterations;
                activity_pos += 2;
            } else {
                // A duplicated pointer will have the following two outer_fn arguments:
                // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
                // (..., metadata! enzyme_dup, ptr, ptr, ...).
                if matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly)
                {
                    assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Pointer);
                }
                // In the case of Dual we don't have assumptions, e.g. f32 would be valid.
                args.push(next_outer_arg);
                outer_pos += 2;
                activity_pos += 1;

                // Now, if width > 1, we need to account for that
                for _ in 1..width {
                    let next_outer_arg = outer_args[outer_pos];
                    args.push(next_outer_arg);
                    outer_pos += 1;
                }
            }
        } else {
            // We do not differentiate with respect to this argument.
            // We already added the metadata and argument above, so just increase the counters.
            outer_pos += 1;
            activity_pos += 1;
        }
    }
}

/// When differentiating `fn_to_diff`, take an `outer_fn` and generate another
/// function with expected naming and calling conventions[^1] which will be
/// discovered by the enzyme LLVM pass and its body populated with the differentiated
/// `fn_to_diff`. `outer_fn` is then modified to have a call to the generated
/// function and handle the differences between the Rust calling convention and
/// Enzyme.
/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
    builder: &mut Builder<'_, 'll, 'tcx>,
    cx: &SimpleCx<'ll>,
    fn_to_diff: &'ll Value,
    outer_name: &str,
    ret_ty: &'ll Type,
    fn_args: &[&'ll Value],
    attrs: AutoDiffAttrs,
    dest: PlaceRef<'tcx, &'ll Value>,
    fnc_tree: FncTree,
) {
    // We have to pick the name depending on whether we want forward or reverse mode autodiff.
    let mut ad_name: String = match attrs.mode {
        DiffMode::Forward => "__enzyme_fwddiff",
        DiffMode::Reverse => "__enzyme_autodiff",
        _ => panic!("logic bug in autodiff, unrecognized mode"),
    }
    .to_string();

    // Add outer_name to ad_name to make it unique, in case users apply autodiff to multiple
    // functions. Unwrap will only panic if LLVM gave us an invalid string.
    ad_name.push_str(outer_name);

    // Let us assume the user wrote the following function square:
    //
    // ```llvm
    // define double @square(double %x) {
    // entry:
    //   %0 = fmul double %x, %x
    //   ret double %0
    // }
    //
    // define double @dsquare(double %x) {
    //   return 0.0;
    // }
    // ```
    //
    // so our `outer_fn` will be `dsquare`. The unsafe code section below now removes the placeholder
    // code and inserts an autodiff call. We also add a declaration for the __enzyme_autodiff call.
    // Again, the arguments to all functions are slightly simplified.
    // ```llvm
    // declare double @__enzyme_autodiff_square(...)
    //
    // define double @dsquare(double %x) {
    // entry:
    //   %0 = tail call double (...) @__enzyme_autodiff_square(double (double)* nonnull @square, double %x)
    //   ret double %0
    // }
    // ```
    let enzyme_ty = unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, TRUE) };

    // FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
    // think a bit more about what should go here.
    let cc = unsafe { llvm::LLVMGetFunctionCallConv(fn_to_diff) };
    let ad_fn = declare_simple_fn(
        cx,
        &ad_name,
        llvm::CallConv::try_from(cc).expect("invalid callconv"),
        llvm::UnnamedAddr::No,
        llvm::Visibility::Default,
        enzyme_ty,
    );

    let num_args = llvm::LLVMCountParams(&fn_to_diff);
    let mut args = Vec::with_capacity(num_args as usize + 1);
    args.push(fn_to_diff);

    let global_primal_ret = cx.declare_global("enzyme_primal_return", cx.type_ptr());
    if matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
        args.push(global_primal_ret);
    }
    if attrs.width > 1 {
        let global_width = cx.declare_global("enzyme_width", cx.type_ptr());
        args.push(global_width);
        args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
    }

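    // As an illustration, with `width = 4` in forward mode the argument list built so far is
    // roughly `(@fn_to_diff, @enzyme_width, i64 4)`, before the per-argument activities and
    // shadow values are appended below.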
    match_args_from_caller_to_enzyme(
        &cx,
        builder,
        attrs.width,
        &mut args,
        &attrs.input_activity,
        fn_args,
    );

    if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() {
        crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree);
    }

    let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);

    let fn_ret_ty = builder.cx.val_ty(call);
    if fn_ret_ty != builder.cx.type_void() && fn_ret_ty != builder.cx.type_struct(&[], false) {
        // If we return void or an empty struct, then our caller (due to how we generated it)
        // does not expect a return value. As such, we have no pointer (or place) into which
        // we could store our value, and would store into an undef, which would cause UB.
        // As such, we just ignore the return value in those cases.
        builder.store_to_place(call, dest.val);
    }
}
391}