rustc_ast/expand/autodiff_attrs.rs
//! This module handles the user-facing autodiff macro. For each `#[autodiff(...)]` attribute,
//! we create an [`AutoDiffItem`] which contains the source and target function names. The source
//! is the function to which the autodiff attribute is applied, and the target is the function
//! generated by us (with a name given by the user as the first autodiff argument).
5
6use std::fmt::{self, Display, Formatter};
7use std::str::FromStr;
8
9use crate::expand::{Decodable, Encodable, HashStable_Generic};
10use crate::ptr::P;
11use crate::{Ty, TyKind};
12
13/// Forward and Reverse Mode are well known names for automatic differentiation implementations.
14/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
15/// are a hack to support higher order derivatives. We need to compute first order derivatives
16/// before we compute second order derivatives, otherwise we would differentiate our placeholder
17/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
18/// as it's already done in the C++ and Julia frontend of Enzyme.
19///
20/// (FIXME) remove *First variants.
21/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
22/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
23#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
24pub enum DiffMode {
25 /// No autodiff is applied (used during error handling).
26 Error,
27 /// The primal function which we will differentiate.
28 Source,
29 /// The target function, to be created using forward mode AD.
30 Forward,
31 /// The target function, to be created using reverse mode AD.
32 Reverse,
33}
34
35/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
36/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
37/// we add to the previous shadow value. To not surprise users, we picked different names.
38/// Dual numbers is also a quite well known name for forward mode AD types.
39#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
40pub enum DiffActivity {
41 /// Implicit or Explicit () return type, so a special case of Const.
42 None,
43 /// Don't compute derivatives with respect to this input/output.
44 Const,
45 /// Reverse Mode, Compute derivatives for this scalar input/output.
46 Active,
47 /// Reverse Mode, Compute derivatives for this scalar output, but don't compute
48 /// the original return value.
49 ActiveOnly,
50 /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
51 /// with it.
52 Dual,
53 /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
54 /// with it. Drop the code which updates the original input/output for maximum performance.
55 DualOnly,
56 /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
57 Duplicated,
58 /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
59 /// Drop the code which updates the original input for maximum performance.
60 DuplicatedOnly,
61 /// All Integers must be Const, but these are used to mark the integer which represents the
62 /// length of a slice/vec. This is used for safety checks on slices.
63 FakeActivitySize,
64}
65/// We generate one of these structs for each `#[autodiff(...)]` attribute.
66#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
67pub struct AutoDiffItem {
68 /// The name of the function getting differentiated
69 pub source: String,
70 /// The name of the function being generated
71 pub target: String,
72 pub attrs: AutoDiffAttrs,
73}
74
75#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
76pub struct AutoDiffAttrs {
77 /// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
78 /// e.g. in the [JAX
79 /// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
80 pub mode: DiffMode,
81 pub ret_activity: DiffActivity,
82 pub input_activity: Vec<DiffActivity>,
83}
84
85impl DiffMode {
86 pub fn is_rev(&self) -> bool {
87 matches!(self, DiffMode::Reverse)
88 }
89 pub fn is_fwd(&self) -> bool {
90 matches!(self, DiffMode::Forward)
91 }
92}
93
94impl Display for DiffMode {
95 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
96 match self {
97 DiffMode::Error => write!(f, "Error"),
98 DiffMode::Source => write!(f, "Source"),
99 DiffMode::Forward => write!(f, "Forward"),
100 DiffMode::Reverse => write!(f, "Reverse"),
101 }
102 }
103}
104
105/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...).
106/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...).
107/// Const is valid for all cases and means that we don't compute derivatives wrt. this output.
108/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg,
109/// but this is too complex to verify here. Also it's just a logic error if users get this wrong.
110pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
111 if activity == DiffActivity::None {
112 // Only valid if primal returns (), but we can't check that here.
113 return true;
114 }
115 match mode {
116 DiffMode::Error => false,
117 DiffMode::Source => false,
118 DiffMode::Forward => {
119 activity == DiffActivity::Dual
120 || activity == DiffActivity::DualOnly
121 || activity == DiffActivity::Const
122 }
123 DiffMode::Reverse => {
124 activity == DiffActivity::Const
125 || activity == DiffActivity::Active
126 || activity == DiffActivity::ActiveOnly
127 }
128 }
129}
130
131/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value
132/// for the given argument, but we generally can't know the size of such a type.
133/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated,
134/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value
135/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent
136/// users here from marking scalars as Duplicated, due to type aliases.
137pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
138 use DiffActivity::*;
139 // It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
140 if matches!(activity, Const) {
141 return true;
142 }
143 if matches!(activity, Dual | DualOnly) {
144 return true;
145 }
146 // FIXME(ZuseZ4) We should make this more robust to also
147 // handle type aliases. Once that is done, we can be more restrictive here.
148 if matches!(activity, Active | ActiveOnly) {
149 return true;
150 }
151 matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..))
152 && matches!(activity, Duplicated | DuplicatedOnly)
153}
154pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
155 use DiffActivity::*;
156 return match mode {
157 DiffMode::Error => false,
158 DiffMode::Source => false,
159 DiffMode::Forward => {
160 matches!(activity, Dual | DualOnly | Const)
161 }
162 DiffMode::Reverse => {
163 matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
164 }
165 };
166}
167
168impl Display for DiffActivity {
169 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
170 match self {
171 DiffActivity::None => write!(f, "None"),
172 DiffActivity::Const => write!(f, "Const"),
173 DiffActivity::Active => write!(f, "Active"),
174 DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
175 DiffActivity::Dual => write!(f, "Dual"),
176 DiffActivity::DualOnly => write!(f, "DualOnly"),
177 DiffActivity::Duplicated => write!(f, "Duplicated"),
178 DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
179 DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
180 }
181 }
182}
183
184impl FromStr for DiffMode {
185 type Err = ();
186
187 fn from_str(s: &str) -> Result<DiffMode, ()> {
188 match s {
189 "Error" => Ok(DiffMode::Error),
190 "Source" => Ok(DiffMode::Source),
191 "Forward" => Ok(DiffMode::Forward),
192 "Reverse" => Ok(DiffMode::Reverse),
193 _ => Err(()),
194 }
195 }
196}
197impl FromStr for DiffActivity {
198 type Err = ();
199
200 fn from_str(s: &str) -> Result<DiffActivity, ()> {
201 match s {
202 "None" => Ok(DiffActivity::None),
203 "Active" => Ok(DiffActivity::Active),
204 "ActiveOnly" => Ok(DiffActivity::ActiveOnly),
205 "Const" => Ok(DiffActivity::Const),
206 "Dual" => Ok(DiffActivity::Dual),
207 "DualOnly" => Ok(DiffActivity::DualOnly),
208 "Duplicated" => Ok(DiffActivity::Duplicated),
209 "DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
210 _ => Err(()),
211 }
212 }
213}
214
215impl AutoDiffAttrs {
216 pub fn has_ret_activity(&self) -> bool {
217 self.ret_activity != DiffActivity::None
218 }
219 pub fn has_active_only_ret(&self) -> bool {
220 self.ret_activity == DiffActivity::ActiveOnly
221 }
222
223 pub const fn error() -> Self {
224 AutoDiffAttrs {
225 mode: DiffMode::Error,
226 ret_activity: DiffActivity::None,
227 input_activity: Vec::new(),
228 }
229 }
230 pub fn source() -> Self {
231 AutoDiffAttrs {
232 mode: DiffMode::Source,
233 ret_activity: DiffActivity::None,
234 input_activity: Vec::new(),
235 }
236 }
237
238 pub fn is_active(&self) -> bool {
239 self.mode != DiffMode::Error
240 }
241
242 pub fn is_source(&self) -> bool {
243 self.mode == DiffMode::Source
244 }
245 pub fn apply_autodiff(&self) -> bool {
246 !matches!(self.mode, DiffMode::Error | DiffMode::Source)
247 }
248
249 pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
250 AutoDiffItem { source, target, attrs: self }
251 }
252}
253
254impl fmt::Display for AutoDiffItem {
255 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256 write!(f, "Differentiating {} -> {}", self.source, self.target)?;
257 write!(f, " with attributes: {:?}", self.attrs)
258 }
259}