rustc_ast/expand/autodiff_attrs.rs
//! This module handles the user-facing autodiff macro. For each `#[autodiff(...)]` attribute,
//! we create an [`AutoDiffItem`] which contains the source and target function names. The source
//! is the function to which the autodiff attribute is applied, and the target is the function
//! we generate (named by the user via the first autodiff argument).
5
6use std::fmt::{self, Display, Formatter};
7use std::str::FromStr;
8
9use crate::expand::{Decodable, Encodable, HashStable_Generic};
10use crate::ptr::P;
11use crate::{Ty, TyKind};
12
13/// Forward and Reverse Mode are well known names for automatic differentiation implementations.
14/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
15/// are a hack to support higher order derivatives. We need to compute first order derivatives
16/// before we compute second order derivatives, otherwise we would differentiate our placeholder
17/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
18/// as it's already done in the C++ and Julia frontend of Enzyme.
19///
20/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
21/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
22#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
23pub enum DiffMode {
24 /// No autodiff is applied (used during error handling).
25 Error,
26 /// The primal function which we will differentiate.
27 Source,
28 /// The target function, to be created using forward mode AD.
29 Forward,
30 /// The target function, to be created using reverse mode AD.
31 Reverse,
32}
33
34/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
35/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
36/// we add to the previous shadow value. To not surprise users, we picked different names.
37/// Dual numbers is also a quite well known name for forward mode AD types.
38#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
39pub enum DiffActivity {
40 /// Implicit or Explicit () return type, so a special case of Const.
41 None,
42 /// Don't compute derivatives with respect to this input/output.
43 Const,
44 /// Reverse Mode, Compute derivatives for this scalar input/output.
45 Active,
46 /// Reverse Mode, Compute derivatives for this scalar output, but don't compute
47 /// the original return value.
48 ActiveOnly,
49 /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
50 /// with it.
51 Dual,
52 /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
53 /// with it. Drop the code which updates the original input/output for maximum performance.
54 DualOnly,
55 /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
56 Duplicated,
57 /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
58 /// Drop the code which updates the original input for maximum performance.
59 DuplicatedOnly,
60 /// All Integers must be Const, but these are used to mark the integer which represents the
61 /// length of a slice/vec. This is used for safety checks on slices.
62 FakeActivitySize,
63}
64/// We generate one of these structs for each `#[autodiff(...)]` attribute.
65#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
66pub struct AutoDiffItem {
67 /// The name of the function getting differentiated
68 pub source: String,
69 /// The name of the function being generated
70 pub target: String,
71 pub attrs: AutoDiffAttrs,
72}
73
74#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
75pub struct AutoDiffAttrs {
76 /// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
77 /// e.g. in the [JAX
78 /// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
79 pub mode: DiffMode,
80 pub ret_activity: DiffActivity,
81 pub input_activity: Vec<DiffActivity>,
82}
83
84impl DiffMode {
85 pub fn is_rev(&self) -> bool {
86 matches!(self, DiffMode::Reverse)
87 }
88 pub fn is_fwd(&self) -> bool {
89 matches!(self, DiffMode::Forward)
90 }
91}
92
93impl Display for DiffMode {
94 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
95 match self {
96 DiffMode::Error => write!(f, "Error"),
97 DiffMode::Source => write!(f, "Source"),
98 DiffMode::Forward => write!(f, "Forward"),
99 DiffMode::Reverse => write!(f, "Reverse"),
100 }
101 }
102}
103
104/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...).
105/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...).
106/// Const is valid for all cases and means that we don't compute derivatives wrt. this output.
107/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg,
108/// but this is too complex to verify here. Also it's just a logic error if users get this wrong.
109pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
110 if activity == DiffActivity::None {
111 // Only valid if primal returns (), but we can't check that here.
112 return true;
113 }
114 match mode {
115 DiffMode::Error => false,
116 DiffMode::Source => false,
117 DiffMode::Forward => {
118 activity == DiffActivity::Dual
119 || activity == DiffActivity::DualOnly
120 || activity == DiffActivity::Const
121 }
122 DiffMode::Reverse => {
123 activity == DiffActivity::Const
124 || activity == DiffActivity::Active
125 || activity == DiffActivity::ActiveOnly
126 }
127 }
128}
129
130/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value
131/// for the given argument, but we generally can't know the size of such a type.
132/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated,
133/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value
134/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent
135/// users here from marking scalars as Duplicated, due to type aliases.
136pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
137 use DiffActivity::*;
138 // It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
139 if matches!(activity, Const) {
140 return true;
141 }
142 if matches!(activity, Dual | DualOnly) {
143 return true;
144 }
145 // FIXME(ZuseZ4) We should make this more robust to also
146 // handle type aliases. Once that is done, we can be more restrictive here.
147 if matches!(activity, Active | ActiveOnly) {
148 return true;
149 }
150 matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..))
151 && matches!(activity, Duplicated | DuplicatedOnly)
152}
153pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
154 use DiffActivity::*;
155 return match mode {
156 DiffMode::Error => false,
157 DiffMode::Source => false,
158 DiffMode::Forward => {
159 matches!(activity, Dual | DualOnly | Const)
160 }
161 DiffMode::Reverse => {
162 matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
163 }
164 };
165}
166
167impl Display for DiffActivity {
168 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
169 match self {
170 DiffActivity::None => write!(f, "None"),
171 DiffActivity::Const => write!(f, "Const"),
172 DiffActivity::Active => write!(f, "Active"),
173 DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
174 DiffActivity::Dual => write!(f, "Dual"),
175 DiffActivity::DualOnly => write!(f, "DualOnly"),
176 DiffActivity::Duplicated => write!(f, "Duplicated"),
177 DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
178 DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
179 }
180 }
181}
182
183impl FromStr for DiffMode {
184 type Err = ();
185
186 fn from_str(s: &str) -> Result<DiffMode, ()> {
187 match s {
188 "Error" => Ok(DiffMode::Error),
189 "Source" => Ok(DiffMode::Source),
190 "Forward" => Ok(DiffMode::Forward),
191 "Reverse" => Ok(DiffMode::Reverse),
192 _ => Err(()),
193 }
194 }
195}
196impl FromStr for DiffActivity {
197 type Err = ();
198
199 fn from_str(s: &str) -> Result<DiffActivity, ()> {
200 match s {
201 "None" => Ok(DiffActivity::None),
202 "Active" => Ok(DiffActivity::Active),
203 "ActiveOnly" => Ok(DiffActivity::ActiveOnly),
204 "Const" => Ok(DiffActivity::Const),
205 "Dual" => Ok(DiffActivity::Dual),
206 "DualOnly" => Ok(DiffActivity::DualOnly),
207 "Duplicated" => Ok(DiffActivity::Duplicated),
208 "DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
209 _ => Err(()),
210 }
211 }
212}
213
214impl AutoDiffAttrs {
215 pub fn has_ret_activity(&self) -> bool {
216 self.ret_activity != DiffActivity::None
217 }
218 pub fn has_active_only_ret(&self) -> bool {
219 self.ret_activity == DiffActivity::ActiveOnly
220 }
221
222 pub const fn error() -> Self {
223 AutoDiffAttrs {
224 mode: DiffMode::Error,
225 ret_activity: DiffActivity::None,
226 input_activity: Vec::new(),
227 }
228 }
229 pub fn source() -> Self {
230 AutoDiffAttrs {
231 mode: DiffMode::Source,
232 ret_activity: DiffActivity::None,
233 input_activity: Vec::new(),
234 }
235 }
236
237 pub fn is_active(&self) -> bool {
238 self.mode != DiffMode::Error
239 }
240
241 pub fn is_source(&self) -> bool {
242 self.mode == DiffMode::Source
243 }
244 pub fn apply_autodiff(&self) -> bool {
245 !matches!(self.mode, DiffMode::Error | DiffMode::Source)
246 }
247
248 pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
249 AutoDiffItem { source, target, attrs: self }
250 }
251}
252
253impl fmt::Display for AutoDiffItem {
254 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
255 write!(f, "Differentiating {} -> {}", self.source, self.target)?;
256 write!(f, " with attributes: {:?}", self.attrs)
257 }
258}