rustc_ast/expand/
autodiff_attrs.rs

//! This module handles the user-facing autodiff macro. For each `#[autodiff(...)]` attribute,
//! we create an [`AutoDiffItem`] which contains the source and target function names. The source
//! is the function to which the autodiff attribute is applied, and the target is the function
//! getting generated by us (with a name given by the user as the first autodiff arg).
5
6use std::fmt::{self, Display, Formatter};
7use std::str::FromStr;
8
9use crate::expand::{Decodable, Encodable, HashStable_Generic};
10use crate::ptr::P;
11use crate::{Ty, TyKind};
12
13/// Forward and Reverse Mode are well known names for automatic differentiation implementations.
14/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
15/// are a hack to support higher order derivatives. We need to compute first order derivatives
16/// before we compute second order derivatives, otherwise we would differentiate our placeholder
17/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
18/// as it's already done in the C++ and Julia frontend of Enzyme.
19///
20/// (FIXME) remove *First variants.
21/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
22/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
23#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
24pub enum DiffMode {
25    /// No autodiff is applied (used during error handling).
26    Error,
27    /// The primal function which we will differentiate.
28    Source,
29    /// The target function, to be created using forward mode AD.
30    Forward,
31    /// The target function, to be created using reverse mode AD.
32    Reverse,
33}
34
35/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
36/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
37/// we add to the previous shadow value. To not surprise users, we picked different names.
38/// Dual numbers is also a quite well known name for forward mode AD types.
39#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
40pub enum DiffActivity {
41    /// Implicit or Explicit () return type, so a special case of Const.
42    None,
43    /// Don't compute derivatives with respect to this input/output.
44    Const,
45    /// Reverse Mode, Compute derivatives for this scalar input/output.
46    Active,
47    /// Reverse Mode, Compute derivatives for this scalar output, but don't compute
48    /// the original return value.
49    ActiveOnly,
50    /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
51    /// with it.
52    Dual,
53    /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
54    /// with it. Drop the code which updates the original input/output for maximum performance.
55    DualOnly,
56    /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
57    Duplicated,
58    /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
59    /// Drop the code which updates the original input for maximum performance.
60    DuplicatedOnly,
61    /// All Integers must be Const, but these are used to mark the integer which represents the
62    /// length of a slice/vec. This is used for safety checks on slices.
63    FakeActivitySize,
64}
65/// We generate one of these structs for each `#[autodiff(...)]` attribute.
66#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
67pub struct AutoDiffItem {
68    /// The name of the function getting differentiated
69    pub source: String,
70    /// The name of the function being generated
71    pub target: String,
72    pub attrs: AutoDiffAttrs,
73}
74
75#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
76pub struct AutoDiffAttrs {
77    /// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
78    /// e.g. in the [JAX
79    /// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
80    pub mode: DiffMode,
81    pub ret_activity: DiffActivity,
82    pub input_activity: Vec<DiffActivity>,
83}
84
85impl DiffMode {
86    pub fn is_rev(&self) -> bool {
87        matches!(self, DiffMode::Reverse)
88    }
89    pub fn is_fwd(&self) -> bool {
90        matches!(self, DiffMode::Forward)
91    }
92}
93
94impl Display for DiffMode {
95    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
96        match self {
97            DiffMode::Error => write!(f, "Error"),
98            DiffMode::Source => write!(f, "Source"),
99            DiffMode::Forward => write!(f, "Forward"),
100            DiffMode::Reverse => write!(f, "Reverse"),
101        }
102    }
103}
104
105/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...).
106/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...).
107/// Const is valid for all cases and means that we don't compute derivatives wrt. this output.
108/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg,
109/// but this is too complex to verify here. Also it's just a logic error if users get this wrong.
110pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
111    if activity == DiffActivity::None {
112        // Only valid if primal returns (), but we can't check that here.
113        return true;
114    }
115    match mode {
116        DiffMode::Error => false,
117        DiffMode::Source => false,
118        DiffMode::Forward => {
119            activity == DiffActivity::Dual
120                || activity == DiffActivity::DualOnly
121                || activity == DiffActivity::Const
122        }
123        DiffMode::Reverse => {
124            activity == DiffActivity::Const
125                || activity == DiffActivity::Active
126                || activity == DiffActivity::ActiveOnly
127        }
128    }
129}
130
131/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value
132/// for the given argument, but we generally can't know the size of such a type.
133/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated,
134/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value
135/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent
136/// users here from marking scalars as Duplicated, due to type aliases.
137pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
138    use DiffActivity::*;
139    // It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
140    if matches!(activity, Const) {
141        return true;
142    }
143    if matches!(activity, Dual | DualOnly) {
144        return true;
145    }
146    // FIXME(ZuseZ4) We should make this more robust to also
147    // handle type aliases. Once that is done, we can be more restrictive here.
148    if matches!(activity, Active | ActiveOnly) {
149        return true;
150    }
151    matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..))
152        && matches!(activity, Duplicated | DuplicatedOnly)
153}
154pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
155    use DiffActivity::*;
156    return match mode {
157        DiffMode::Error => false,
158        DiffMode::Source => false,
159        DiffMode::Forward => {
160            matches!(activity, Dual | DualOnly | Const)
161        }
162        DiffMode::Reverse => {
163            matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
164        }
165    };
166}
167
168impl Display for DiffActivity {
169    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
170        match self {
171            DiffActivity::None => write!(f, "None"),
172            DiffActivity::Const => write!(f, "Const"),
173            DiffActivity::Active => write!(f, "Active"),
174            DiffActivity::ActiveOnly => write!(f, "ActiveOnly"),
175            DiffActivity::Dual => write!(f, "Dual"),
176            DiffActivity::DualOnly => write!(f, "DualOnly"),
177            DiffActivity::Duplicated => write!(f, "Duplicated"),
178            DiffActivity::DuplicatedOnly => write!(f, "DuplicatedOnly"),
179            DiffActivity::FakeActivitySize => write!(f, "FakeActivitySize"),
180        }
181    }
182}
183
184impl FromStr for DiffMode {
185    type Err = ();
186
187    fn from_str(s: &str) -> Result<DiffMode, ()> {
188        match s {
189            "Error" => Ok(DiffMode::Error),
190            "Source" => Ok(DiffMode::Source),
191            "Forward" => Ok(DiffMode::Forward),
192            "Reverse" => Ok(DiffMode::Reverse),
193            _ => Err(()),
194        }
195    }
196}
197impl FromStr for DiffActivity {
198    type Err = ();
199
200    fn from_str(s: &str) -> Result<DiffActivity, ()> {
201        match s {
202            "None" => Ok(DiffActivity::None),
203            "Active" => Ok(DiffActivity::Active),
204            "ActiveOnly" => Ok(DiffActivity::ActiveOnly),
205            "Const" => Ok(DiffActivity::Const),
206            "Dual" => Ok(DiffActivity::Dual),
207            "DualOnly" => Ok(DiffActivity::DualOnly),
208            "Duplicated" => Ok(DiffActivity::Duplicated),
209            "DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
210            _ => Err(()),
211        }
212    }
213}
214
215impl AutoDiffAttrs {
216    pub fn has_ret_activity(&self) -> bool {
217        self.ret_activity != DiffActivity::None
218    }
219    pub fn has_active_only_ret(&self) -> bool {
220        self.ret_activity == DiffActivity::ActiveOnly
221    }
222
223    pub const fn error() -> Self {
224        AutoDiffAttrs {
225            mode: DiffMode::Error,
226            ret_activity: DiffActivity::None,
227            input_activity: Vec::new(),
228        }
229    }
230    pub fn source() -> Self {
231        AutoDiffAttrs {
232            mode: DiffMode::Source,
233            ret_activity: DiffActivity::None,
234            input_activity: Vec::new(),
235        }
236    }
237
238    pub fn is_active(&self) -> bool {
239        self.mode != DiffMode::Error
240    }
241
242    pub fn is_source(&self) -> bool {
243        self.mode == DiffMode::Source
244    }
245    pub fn apply_autodiff(&self) -> bool {
246        !matches!(self.mode, DiffMode::Error | DiffMode::Source)
247    }
248
249    pub fn into_item(self, source: String, target: String) -> AutoDiffItem {
250        AutoDiffItem { source, target, attrs: self }
251    }
252}
253
254impl fmt::Display for AutoDiffItem {
255    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
256        write!(f, "Differentiating {} -> {}", self.source, self.target)?;
257        write!(f, " with attributes: {:?}", self.attrs)
258    }
259}