rustc_ast/expand/autodiff_attrs.rs
//! This module handles the user-facing autodiff macro. For each `#[autodiff(...)]` attribute,
//! we create an [`AutoDiffItem`] which contains the source and target function names. The source
//! is the function to which the autodiff attribute is applied, and the target is the function
//! getting generated by us (with a name given by the user as the first autodiff arg).
5
6use std::fmt::{self, Display, Formatter};
7use std::str::FromStr;
8
9use crate::expand::typetree::TypeTree;
10use crate::expand::{Decodable, Encodable, HashStable_Generic};
11use crate::{Ty, TyKind};
12
13/// Forward and Reverse Mode are well known names for automatic differentiation implementations.
14/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
15/// are a hack to support higher order derivatives. We need to compute first order derivatives
16/// before we compute second order derivatives, otherwise we would differentiate our placeholder
17/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
18/// as it's already done in the C++ and Julia frontend of Enzyme.
19///
20/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
21/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
22#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
23pub enum DiffMode {
24 /// No autodiff is applied (used during error handling).
25 Error,
26 /// The primal function which we will differentiate.
27 Source,
28 /// The target function, to be created using forward mode AD.
29 Forward,
30 /// The target function, to be created using reverse mode AD.
31 Reverse,
32}
33
34/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
35/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
36/// we add to the previous shadow value. To not surprise users, we picked different names.
37/// Dual numbers is also a quite well known name for forward mode AD types.
38#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
39pub enum DiffActivity {
40 /// Implicit or Explicit () return type, so a special case of Const.
41 None,
42 /// Don't compute derivatives with respect to this input/output.
43 Const,
44 /// Reverse Mode, Compute derivatives for this scalar input/output.
45 Active,
46 /// Reverse Mode, Compute derivatives for this scalar output, but don't compute
47 /// the original return value.
48 ActiveOnly,
49 /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
50 /// with it.
51 Dual,
52 /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
53 /// with it. It expects the shadow argument to be `width` times larger than the original
54 /// input/output.
55 Dualv,
56 /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
57 /// with it. Drop the code which updates the original input/output for maximum performance.
58 DualOnly,
59 /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
60 /// with it. Drop the code which updates the original input/output for maximum performance.
61 /// It expects the shadow argument to be `width` times larger than the original input/output.
62 DualvOnly,
63 /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
64 Duplicated,
65 /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
66 /// Drop the code which updates the original input for maximum performance.
67 DuplicatedOnly,
68 /// All Integers must be Const, but these are used to mark the integer which represents the
69 /// length of a slice/vec. This is used for safety checks on slices.
70 /// The integer (if given) specifies the size of the slice element in bytes.
71 FakeActivitySize(Option<u32>),
72}
73
74impl DiffActivity {
75 pub fn is_dual_or_const(&self) -> bool {
76 use DiffActivity::*;
77 matches!(self, |Dual| DualOnly | Dualv | DualvOnly | Const)
78 }
79}
80/// We generate one of these structs for each `#[autodiff(...)]` attribute.
81#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
82pub struct AutoDiffItem {
83 /// The name of the function getting differentiated
84 pub source: String,
85 /// The name of the function being generated
86 pub target: String,
87 pub attrs: AutoDiffAttrs,
88 pub inputs: Vec<TypeTree>,
89 pub output: TypeTree,
90}
91
92#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
93pub struct AutoDiffAttrs {
94 /// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
95 /// e.g. in the [JAX
96 /// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
97 pub mode: DiffMode,
98 /// A user-provided, batching width. If not given, we will default to 1 (no batching).
99 /// Calling a differentiated, non-batched function through a loop 100 times is equivalent to:
100 /// - Calling the function 50 times with a batch size of 2
101 /// - Calling the function 25 times with a batch size of 4,
102 /// etc. A batched function takes more (or longer) arguments, and might be able to benefit from
103 /// cache locality, better re-usal of primal values, and other optimizations.
104 /// We will (before LLVM's vectorizer runs) just generate most LLVM-IR instructions `width`
105 /// times, so this massively increases code size. As such, values like 1024 are unlikely to
106 /// work. We should consider limiting this to u8 or u16, but will leave it at u32 for
107 /// experiments for now and focus on documenting the implications of a large width.
108 pub width: u32,
109 pub ret_activity: DiffActivity,
110 pub input_activity: Vec<DiffActivity>,
111}
112
113impl AutoDiffAttrs {
114 pub fn has_primal_ret(&self) -> bool {
115 matches!(self.ret_activity, DiffActivity::Active | DiffActivity::Dual)
116 }
117}
118
119impl DiffMode {
120 pub fn is_rev(&self) -> bool {
121 matches!(self, DiffMode::Reverse)
122 }
123 pub fn is_fwd(&self) -> bool {
124 matches!(self, DiffMode::Forward)
125 }
126}
127
128impl Display for DiffMode {
129 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
130 match self {
131 DiffMode::Error => write!(f, "Error"),
132 DiffMode::Source => write!(f, "Source"),
133 DiffMode::Forward => write!(f, "Forward"),
134 DiffMode::Reverse => write!(f, "Reverse"),
135 }
136 }
137}
138
139/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...).
140/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...).
141/// Const is valid for all cases and means that we don't compute derivatives wrt. this output.
142/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg,
143/// but this is too complex to verify here. Also it's just a logic error if users get this wrong.
144pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
145 if activity == DiffActivity::None {
146 // Only valid if primal returns (), but we can't check that here.
147 return true;
148 }
149 match mode {
150 DiffMode::Error => false,
151 DiffMode::Source => false,
152 DiffMode::Forward => activity.is_dual_or_const(),
153 DiffMode::Reverse => {
154 activity == DiffActivity::Const
155 || activity == DiffActivity::Active
156 || activity == DiffActivity::ActiveOnly
157 }
158 }
159}
160
161/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value
162/// for the given argument, but we generally can't know the size of such a type.
163/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated,
164/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value
165/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent
166/// users here from marking scalars as Duplicated, due to type aliases.
167pub fn valid_ty_for_activity(ty: &Box<Ty>, activity: DiffActivity) -> bool {
168 use DiffActivity::*;
169 // It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
170 // Dual variants also support all types.
171 if activity.is_dual_or_const() {
172 return true;
173 }
174 // FIXME(ZuseZ4) We should make this more robust to also
175 // handle type aliases. Once that is done, we can be more restrictive here.
176 if matches!(activity, Active | ActiveOnly) {
177 return true;
178 }
179 matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..))
180 && matches!(activity, Duplicated | DuplicatedOnly)
181}
182pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
183 use DiffActivity::*;
184 return match mode {
185 DiffMode::Error => false,
186 DiffMode::Source => false,
187 DiffMode::Forward => activity.is_dual_or_const(),
188 DiffMode::Reverse => {
189 matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
190 }
191 };
192}
193
194impl Display for DiffActivity {
195 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
196 match self {
197 DiffActivity::None => write!(f, "None"),
198 DiffActivity::Const => write!(f, "Const"),
199 DiffActivity::Active => write!(f, "Active"),
200 DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
201 DiffActivity::Dual => write!(f, "Dual"),
202 DiffActivity::Dualv => write!(f, "Dualv"),
203 DiffActivity::DualOnly => write!(f, "DualOnly"),
204 DiffActivity::DualvOnly => write!(f, "DualvOnly"),
205 DiffActivity::Duplicated => write!(f, "Duplicated"),
206 DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
207 DiffActivity::FakeActivitySize(s) => write!(f, "FakeActivitySize({:?})", s),
208 }
209 }
210}
211
212impl FromStr for DiffMode {
213 type Err = ();
214
215 fn from_str(s: &str) -> Result<DiffMode, ()> {
216 match s {
217 "Error" => Ok(DiffMode::Error),
218 "Source" => Ok(DiffMode::Source),
219 "Forward" => Ok(DiffMode::Forward),
220 "Reverse" => Ok(DiffMode::Reverse),
221 _ => Err(()),
222 }
223 }
224}
225impl FromStr for DiffActivity {
226 type Err = ();
227
228 fn from_str(s: &str) -> Result<DiffActivity, ()> {
229 match s {
230 "None" => Ok(DiffActivity::None),
231 "Active" => Ok(DiffActivity::Active),
232 "ActiveOnly" => Ok(DiffActivity::ActiveOnly),
233 "Const" => Ok(DiffActivity::Const),
234 "Dual" => Ok(DiffActivity::Dual),
235 "Dualv" => Ok(DiffActivity::Dualv),
236 "DualOnly" => Ok(DiffActivity::DualOnly),
237 "DualvOnly" => Ok(DiffActivity::DualvOnly),
238 "Duplicated" => Ok(DiffActivity::Duplicated),
239 "DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
240 _ => Err(()),
241 }
242 }
243}
244
245impl AutoDiffAttrs {
246 pub fn has_ret_activity(&self) -> bool {
247 self.ret_activity != DiffActivity::None
248 }
249 pub fn has_active_only_ret(&self) -> bool {
250 self.ret_activity == DiffActivity::ActiveOnly
251 }
252
253 pub const fn error() -> Self {
254 AutoDiffAttrs {
255 mode: DiffMode::Error,
256 width: 0,
257 ret_activity: DiffActivity::None,
258 input_activity: Vec::new(),
259 }
260 }
261 pub fn source() -> Self {
262 AutoDiffAttrs {
263 mode: DiffMode::Source,
264 width: 0,
265 ret_activity: DiffActivity::None,
266 input_activity: Vec::new(),
267 }
268 }
269
270 pub fn is_active(&self) -> bool {
271 self.mode != DiffMode::Error
272 }
273
274 pub fn is_source(&self) -> bool {
275 self.mode == DiffMode::Source
276 }
277 pub fn apply_autodiff(&self) -> bool {
278 !matches!(self.mode, DiffMode::Error | DiffMode::Source)
279 }
280
281 pub fn into_item(
282 self,
283 source: String,
284 target: String,
285 inputs: Vec<TypeTree>,
286 output: TypeTree,
287 ) -> AutoDiffItem {
288 AutoDiffItem { source, target, inputs, output, attrs: self }
289 }
290}
291
292impl fmt::Display for AutoDiffItem {
293 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
294 write!(f, "Differentiating {} -> {}", self.source, self.target)?;
295 write!(f, " with attributes: {:?}", self.attrs)?;
296 write!(f, " with inputs: {:?}", self.inputs)?;
297 write!(f, " with output: {:?}", self.output)
298 }
299}