use std::ptr;

use rustc_ast::expand::autodiff_attrs::{AutoDiffAttrs, DiffActivity, DiffMode};
use rustc_ast::expand::typetree::FncTree;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_middle::ty::{Instance, PseudoCanonicalInput, TyCtxt, TypingEnv};
use rustc_middle::{bug, ty};
use rustc_target::callconv::PassMode;
use tracing::debug;

use crate::builder::{Builder, PlaceRef, UNNAMED};
use crate::context::SimpleCx;
use crate::declare::declare_simple_fn;
use crate::llvm::{self, TRUE, Type, Value};
1617pub(crate) fn adjust_activity_to_abi<'tcx>(
18 tcx: TyCtxt<'tcx>,
19 instance: Instance<'tcx>,
20 typing_env: TypingEnv<'tcx>,
21 da: &mut Vec<DiffActivity>,
22) {
23let fn_ty = instance.ty(tcx, typing_env);
2425if !#[allow(non_exhaustive_omitted_patterns)] match fn_ty.kind() {
ty::FnDef(..) => true,
_ => false,
}matches!(fn_ty.kind(), ty::FnDef(..)) {
26::rustc_middle::util::bug::bug_fmt(format_args!("expected fn def for autodiff, got {0:?}",
fn_ty));bug!("expected fn def for autodiff, got {:?}", fn_ty);
27 }
2829// We don't actually pass the types back into the type system.
30 // All we do is decide how to handle the arguments.
31let sig = fn_ty.fn_sig(tcx).skip_binder();
3233// FIXME(Sa4dUs): pass proper varargs once we have support for differentiating variadic functions
34let Ok(fn_abi) =
35tcx.fn_abi_of_instance(typing_env.as_query_input((instance, ty::List::empty())))
36else {
37::rustc_middle::util::bug::bug_fmt(format_args!("failed to get fn_abi of instance with empty varargs"));bug!("failed to get fn_abi of instance with empty varargs");
38 };
3940let mut new_activities = ::alloc::vec::Vec::new()vec![];
41let mut new_positions = ::alloc::vec::Vec::new()vec![];
42let mut del_activities = 0;
43for (i, ty) in sig.inputs().iter().enumerate() {
44if let Some(inner_ty) = ty.builtin_deref(true) {
45if inner_ty.is_slice() {
46// Now we need to figure out the size of each slice element in memory to allow
47 // safety checks and usability improvements in the backend.
48let sty = match inner_ty.builtin_index() {
49Some(sty) => sty,
50None => {
51{ ::core::panicking::panic_fmt(format_args!("slice element type unknown")); };panic!("slice element type unknown");
52 }
53 };
54let pci = PseudoCanonicalInput {
55 typing_env: TypingEnv::fully_monomorphized(),
56 value: sty,
57 };
5859let layout = tcx.layout_of(pci);
60let elem_size = match layout {
61Ok(layout) => layout.size,
62Err(_) => {
63::rustc_middle::util::bug::bug_fmt(format_args!("autodiff failed to compute slice element size"));bug!("autodiff failed to compute slice element size");
64 }
65 };
66let elem_size: u32 = elem_size.bytes() as u32;
6768// We know that the length will be passed as extra arg.
69if !da.is_empty() {
70// We are looking at a slice. The length of that slice will become an
71 // extra integer on llvm level. Integers are always const.
72 // However, if the slice get's duplicated, we want to know to later check the
73 // size. So we mark the new size argument as FakeActivitySize.
74 // There is one FakeActivitySize per slice, so for convenience we store the
75 // slice element size in bytes in it. We will use the size in the backend.
76let activity = match da[i] {
77 DiffActivity::DualOnly
78 | DiffActivity::Dual
79 | DiffActivity::Dualv
80 | DiffActivity::DuplicatedOnly
81 | DiffActivity::Duplicated => {
82 DiffActivity::FakeActivitySize(Some(elem_size))
83 }
84 DiffActivity::Const => DiffActivity::Const,
85_ => ::rustc_middle::util::bug::bug_fmt(format_args!("unexpected activity for ptr/ref"))bug!("unexpected activity for ptr/ref"),
86 };
87 new_activities.push(activity);
88 new_positions.push(i + 1);
89 }
9091continue;
92 }
93 }
9495let pci = PseudoCanonicalInput { typing_env: TypingEnv::fully_monomorphized(), value: *ty };
9697let layout = match tcx.layout_of(pci) {
98Ok(layout) => layout.layout,
99Err(_) => {
100::rustc_middle::util::bug::bug_fmt(format_args!("failed to compute layout for type {0:?}",
ty));bug!("failed to compute layout for type {:?}", ty);
101 }
102 };
103104let pass_mode = &fn_abi.args[i].mode;
105106// For ZST, just ignore and don't add its activity, as this arg won't be present
107 // in the LLVM passed to Enzyme.
108 // Some targets pass ZST indirectly in the C ABI, in that case, handle it as a normal arg
109 // FIXME(Sa4dUs): Enforce ZST corresponding diff activity be `Const`
110if *pass_mode == PassMode::Ignore {
111 del_activities += 1;
112 da.remove(i);
113 }
114115// If the argument is lowered as a `ScalarPair`, we need to duplicate its activity.
116 // Otherwise, the number of activities won't match the number of LLVM arguments and
117 // this will lead to errors when verifying the Enzyme call.
118if let rustc_abi::BackendRepr::ScalarPair(_, _) = layout.backend_repr() {
119 new_activities.push(da[i].clone());
120 new_positions.push(i + 1 - del_activities);
121 }
122 }
123// now add the extra activities coming from slices
124 // Reverse order to not invalidate the indices
125for _ in 0..new_activities.len() {
126let pos = new_positions.pop().unwrap();
127let activity = new_activities.pop().unwrap();
128 da.insert(pos, activity);
129 }
130}
131132// When we call the `__enzyme_autodiff` or `__enzyme_fwddiff` function, we need to pass all the
133// original inputs, as well as metadata and the additional shadow arguments.
134// This function matches the arguments from the outer function to the inner enzyme call.
135//
136// This function also considers that Rust level arguments not always match the llvm-ir level
137// arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
138// llvm-ir level. The number of activities matches the number of Rust level arguments, so we
139// need to match those.
140// FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
141// using iterators and peek()?
142fn match_args_from_caller_to_enzyme<'ll, 'tcx>(
143 cx: &SimpleCx<'ll>,
144 builder: &mut Builder<'_, 'll, 'tcx>,
145 width: u32,
146 args: &mut Vec<&'ll Value>,
147 inputs: &[DiffActivity],
148 outer_args: &[&'ll Value],
149) {
150{
use ::tracing::__macro_support::Callsite as _;
static __CALLSITE: ::tracing::callsite::DefaultCallsite =
{
static META: ::tracing::Metadata<'static> =
{
::tracing_core::metadata::Metadata::new("event compiler/rustc_codegen_llvm/src/builder/autodiff.rs:150",
"rustc_codegen_llvm::builder::autodiff",
::tracing::Level::DEBUG,
::tracing_core::__macro_support::Option::Some("compiler/rustc_codegen_llvm/src/builder/autodiff.rs"),
::tracing_core::__macro_support::Option::Some(150u32),
::tracing_core::__macro_support::Option::Some("rustc_codegen_llvm::builder::autodiff"),
::tracing_core::field::FieldSet::new(&["message"],
::tracing_core::callsite::Identifier(&__CALLSITE)),
::tracing::metadata::Kind::EVENT)
};
::tracing::callsite::DefaultCallsite::new(&META)
};
let enabled =
::tracing::Level::DEBUG <= ::tracing::level_filters::STATIC_MAX_LEVEL
&&
::tracing::Level::DEBUG <=
::tracing::level_filters::LevelFilter::current() &&
{
let interest = __CALLSITE.interest();
!interest.is_never() &&
::tracing::__macro_support::__is_enabled(__CALLSITE.metadata(),
interest)
};
if enabled {
(|value_set: ::tracing::field::ValueSet|
{
let meta = __CALLSITE.metadata();
::tracing::Event::dispatch(meta, &value_set);
;
})({
#[allow(unused_imports)]
use ::tracing::field::{debug, display, Value};
let mut iter = __CALLSITE.metadata().fields().iter();
__CALLSITE.metadata().fields().value_set(&[(&::tracing::__macro_support::Iterator::next(&mut iter).expect("FieldSet corrupted (this is a bug)"),
::tracing::__macro_support::Option::Some(&format_args!("matching autodiff arguments")
as &dyn Value))])
});
} else { ; }
};debug!("matching autodiff arguments");
151// We now handle the issue that Rust level arguments not always match the llvm-ir level
152 // arguments. A slice, `&[f32]`, for example, is represented as a pointer and a length on
153 // llvm-ir level. The number of activities matches the number of Rust level arguments, so we
154 // need to match those.
155 // FIXME(ZuseZ4): This logic is a bit more complicated than it should be, can we simplify it
156 // using iterators and peek()?
157let mut outer_pos: usize = 0;
158let mut activity_pos = 0;
159160// We used to use llvm's metadata to instruct enzyme how to differentiate a function.
161 // In debug mode we would use incremental compilation which caused the metadata to be
162 // dropped. This is prevented by now using named globals, which are also understood
163 // by Enzyme.
164let global_const = cx.declare_global("enzyme_const", cx.type_ptr());
165let global_out = cx.declare_global("enzyme_out", cx.type_ptr());
166let global_dup = cx.declare_global("enzyme_dup", cx.type_ptr());
167let global_dupv = cx.declare_global("enzyme_dupv", cx.type_ptr());
168let global_dupnoneed = cx.declare_global("enzyme_dupnoneed", cx.type_ptr());
169let global_dupnoneedv = cx.declare_global("enzyme_dupnoneedv", cx.type_ptr());
170171while activity_pos < inputs.len() {
172let diff_activity = inputs[activity_pos as usize];
173// Duplicated arguments received a shadow argument, into which enzyme will write the
174 // gradient.
175let (activity, duplicated): (&Value, bool) = match diff_activity {
176 DiffActivity::None => { ::core::panicking::panic_fmt(format_args!("not a valid input activity")); }panic!("not a valid input activity"),
177 DiffActivity::Const => (global_const, false),
178 DiffActivity::Active => (global_out, false),
179 DiffActivity::ActiveOnly => (global_out, false),
180 DiffActivity::Dual => (global_dup, true),
181 DiffActivity::Dualv => (global_dupv, true),
182 DiffActivity::DualOnly => (global_dupnoneed, true),
183 DiffActivity::DualvOnly => (global_dupnoneedv, true),
184 DiffActivity::Duplicated => (global_dup, true),
185 DiffActivity::DuplicatedOnly => (global_dupnoneed, true),
186 DiffActivity::FakeActivitySize(_) => (global_const, false),
187 };
188let outer_arg = outer_args[outer_pos];
189 args.push(activity);
190if #[allow(non_exhaustive_omitted_patterns)] match diff_activity {
DiffActivity::Dualv => true,
_ => false,
}matches!(diff_activity, DiffActivity::Dualv) {
191let next_outer_arg = outer_args[outer_pos + 1];
192let elem_bytes_size: u64 = match inputs[activity_pos + 1] {
193 DiffActivity::FakeActivitySize(Some(s)) => s.into(),
194_ => ::rustc_middle::util::bug::bug_fmt(format_args!("incorrect Dualv handling recognized."))bug!("incorrect Dualv handling recognized."),
195 };
196// stride: sizeof(T) * n_elems.
197 // n_elems is the next integer.
198 // Now we multiply `4 * next_outer_arg` to get the stride.
199let mul = unsafe {
200 llvm::LLVMBuildMul(
201 builder.llbuilder,
202 cx.get_const_int(cx.type_i64(), elem_bytes_size),
203 next_outer_arg,
204 UNNAMED,
205 )
206 };
207 args.push(mul);
208 }
209 args.push(outer_arg);
210if duplicated {
211// We know that duplicated args by construction have a following argument,
212 // so this can not be out of bounds.
213let next_outer_arg = outer_args[outer_pos + 1];
214let next_outer_ty = cx.val_ty(next_outer_arg);
215// FIXME(ZuseZ4): We should add support for Vec here too, but it's less urgent since
216 // vectors behind references (&Vec<T>) are already supported. Users can not pass a
217 // Vec by value for reverse mode, so this would only help forward mode autodiff.
218let slice = {
219if activity_pos + 1 >= inputs.len() {
220// If there is no arg following our ptr, it also can't be a slice,
221 // since that would lead to a ptr, int pair.
222false
223} else {
224let next_activity = inputs[activity_pos + 1];
225// We analyze the MIR types and add this dummy activity if we visit a slice.
226#[allow(non_exhaustive_omitted_patterns)] match next_activity {
DiffActivity::FakeActivitySize(_) => true,
_ => false,
}matches!(next_activity, DiffActivity::FakeActivitySize(_))227 }
228 };
229if slice {
230// A duplicated slice will have the following two outer_fn arguments:
231 // (..., ptr1, int1, ptr2, int2, ...). We add the following llvm-ir to our __enzyme call:
232 // (..., metadata! enzyme_dup, ptr, ptr, int1, ...).
233 // FIXME(ZuseZ4): We will upstream a safety check later which asserts that
234 // int2 >= int1, which means the shadow vector is large enough to store the gradient.
235match (&cx.type_kind(next_outer_ty), &TypeKind::Integer) {
(left_val, right_val) => {
if !(*left_val == *right_val) {
let kind = ::core::panicking::AssertKind::Eq;
::core::panicking::assert_failed(kind, &*left_val, &*right_val,
::core::option::Option::None);
}
}
};assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Integer);
236237let iterations =
238if #[allow(non_exhaustive_omitted_patterns)] match diff_activity {
DiffActivity::Dualv => true,
_ => false,
}matches!(diff_activity, DiffActivity::Dualv) { 1 } else { width as usize };
239240for i in 0..iterations {
241let next_outer_arg2 = outer_args[outer_pos + 2 * (i + 1)];
242let next_outer_ty2 = cx.val_ty(next_outer_arg2);
243match (&cx.type_kind(next_outer_ty2), &TypeKind::Pointer) {
(left_val, right_val) => {
if !(*left_val == *right_val) {
let kind = ::core::panicking::AssertKind::Eq;
::core::panicking::assert_failed(kind, &*left_val, &*right_val,
::core::option::Option::None);
}
}
};assert_eq!(cx.type_kind(next_outer_ty2), TypeKind::Pointer);
244let next_outer_arg3 = outer_args[outer_pos + 2 * (i + 1) + 1];
245let next_outer_ty3 = cx.val_ty(next_outer_arg3);
246match (&cx.type_kind(next_outer_ty3), &TypeKind::Integer) {
(left_val, right_val) => {
if !(*left_val == *right_val) {
let kind = ::core::panicking::AssertKind::Eq;
::core::panicking::assert_failed(kind, &*left_val, &*right_val,
::core::option::Option::None);
}
}
};assert_eq!(cx.type_kind(next_outer_ty3), TypeKind::Integer);
247 args.push(next_outer_arg2);
248 }
249 args.push(global_const);
250 args.push(next_outer_arg);
251 outer_pos += 2 + 2 * iterations;
252 activity_pos += 2;
253 } else {
254// A duplicated pointer will have the following two outer_fn arguments:
255 // (..., ptr, ptr, ...). We add the following llvm-ir to our __enzyme call:
256 // (..., metadata! enzyme_dup, ptr, ptr, ...).
257if #[allow(non_exhaustive_omitted_patterns)] match diff_activity {
DiffActivity::Duplicated | DiffActivity::DuplicatedOnly => true,
_ => false,
}matches!(diff_activity, DiffActivity::Duplicated | DiffActivity::DuplicatedOnly)258 {
259match (&cx.type_kind(next_outer_ty), &TypeKind::Pointer) {
(left_val, right_val) => {
if !(*left_val == *right_val) {
let kind = ::core::panicking::AssertKind::Eq;
::core::panicking::assert_failed(kind, &*left_val, &*right_val,
::core::option::Option::None);
}
}
};assert_eq!(cx.type_kind(next_outer_ty), TypeKind::Pointer);
260 }
261// In the case of Dual we don't have assumptions, e.g. f32 would be valid.
262args.push(next_outer_arg);
263 outer_pos += 2;
264 activity_pos += 1;
265266// Now, if width > 1, we need to account for that
267for _ in 1..width {
268let next_outer_arg = outer_args[outer_pos];
269 args.push(next_outer_arg);
270 outer_pos += 1;
271 }
272 }
273 } else {
274// We do not differentiate with resprect to this argument.
275 // We already added the metadata and argument above, so just increase the counters.
276outer_pos += 1;
277 activity_pos += 1;
278 }
279 }
280}
281282/// When differentiating `fn_to_diff`, take a `outer_fn` and generate another
283/// function with expected naming and calling conventions[^1] which will be
284/// discovered by the enzyme LLVM pass and its body populated with the differentiated
285/// `fn_to_diff`. `outer_fn` is then modified to have a call to the generated
286/// function and handle the differences between the Rust calling convention and
287/// Enzyme.
288/// [^1]: <https://enzyme.mit.edu/getting_started/CallingConvention/>
289// FIXME(ZuseZ4): `outer_fn` should include upstream safety checks to
290// cover some assumptions of enzyme/autodiff, which could lead to UB otherwise.
291pub(crate) fn generate_enzyme_call<'ll, 'tcx>(
292 builder: &mut Builder<'_, 'll, 'tcx>,
293 cx: &SimpleCx<'ll>,
294 fn_to_diff: &'ll Value,
295 outer_name: &str,
296 ret_ty: &'ll Type,
297 fn_args: &[&'ll Value],
298 attrs: AutoDiffAttrs,
299 dest: PlaceRef<'tcx, &'ll Value>,
300 fnc_tree: FncTree,
301) {
302// We have to pick the name depending on whether we want forward or reverse mode autodiff.
303let mut ad_name: String = match attrs.mode {
304 DiffMode::Forward => "__enzyme_fwddiff",
305 DiffMode::Reverse => "__enzyme_autodiff",
306_ => {
::core::panicking::panic_fmt(format_args!("logic bug in autodiff, unrecognized mode"));
}panic!("logic bug in autodiff, unrecognized mode"),
307 }
308 .to_string();
309310// add outer_name to ad_name to make it unique, in case users apply autodiff to multiple
311 // functions. Unwrap will only panic, if LLVM gave us an invalid string.
312ad_name.push_str(outer_name);
313314// Let us assume the user wrote the following function square:
315 //
316 // ```llvm
317 // define double @square(double %x) {
318 // entry:
319 // %0 = fmul double %x, %x
320 // ret double %0
321 // }
322 //
323 // define double @dsquare(double %x) {
324 // return 0.0;
325 // }
326 // ```
327 //
328 // so our `outer_fn` will be `dsquare`. The unsafe code section below now removes the placeholder
329 // code and inserts an autodiff call. We also add a declaration for the __enzyme_autodiff call.
330 // Again, the arguments to all functions are slightly simplified.
331 // ```llvm
332 // declare double @__enzyme_autodiff_square(...)
333 //
334 // define double @dsquare(double %x) {
335 // entry:
336 // %0 = tail call double (...) @__enzyme_autodiff_square(double (double)* nonnull @square, double %x)
337 // ret double %0
338 // }
339 // ```
340let enzyme_ty = unsafe { llvm::LLVMFunctionType(ret_ty, ptr::null(), 0, TRUE) };
341342// FIXME(ZuseZ4): the CC/Addr/Vis values are best effort guesses, we should look at tests and
343 // think a bit more about what should go here.
344let cc = unsafe { llvm::LLVMGetFunctionCallConv(fn_to_diff) };
345let ad_fn = declare_simple_fn(
346cx,
347&ad_name,
348 llvm::CallConv::try_from(cc).expect("invalid callconv"),
349 llvm::UnnamedAddr::No,
350 llvm::Visibility::Default,
351enzyme_ty,
352 );
353354let num_args = llvm::LLVMCountParams(&fn_to_diff);
355let mut args = Vec::with_capacity(num_argsas usize + 1);
356args.push(fn_to_diff);
357358let global_primal_ret = cx.declare_global("enzyme_primal_return", cx.type_ptr());
359if #[allow(non_exhaustive_omitted_patterns)] match attrs.ret_activity {
DiffActivity::Dual | DiffActivity::Active => true,
_ => false,
}matches!(attrs.ret_activity, DiffActivity::Dual | DiffActivity::Active) {
360args.push(global_primal_ret);
361 }
362if attrs.width > 1 {
363let global_width = cx.declare_global("enzyme_width", cx.type_ptr());
364args.push(global_width);
365args.push(cx.get_const_int(cx.type_i64(), attrs.width as u64));
366 }
367368match_args_from_caller_to_enzyme(
369&cx,
370builder,
371attrs.width,
372&mut args,
373&attrs.input_activity,
374fn_args,
375 );
376377if !fnc_tree.args.is_empty() || !fnc_tree.ret.0.is_empty() {
378crate::typetree::add_tt(cx.llmod, cx.llcx, fn_to_diff, fnc_tree);
379 }
380381let call = builder.call(enzyme_ty, None, None, ad_fn, &args, None, None);
382383let fn_ret_ty = builder.cx.val_ty(call);
384if fn_ret_ty != builder.cx.type_void() && fn_ret_ty != builder.cx.type_struct(&[], false) {
385// If we return void or an empty struct, then our caller (due to how we generated it)
386 // does not expect a return value. As such, we have no pointer (or place) into which
387 // we could store our value, and would store into an undef, which would cause UB.
388 // As such, we just ignore the return value in those cases.
389builder.store_to_place(call, dest.val);
390 }
391}