// rustc_ast/expand/autodiff_attrs.rs
//! This module handles the user-facing autodiff macro. For each `#[autodiff(...)]` attribute,
//! we create an [`AutoDiffItem`] which contains the source and target function names. The source
//! is the function to which the autodiff attribute is applied, and the target is the function
//! that we generate (with a name given by the user as the first autodiff argument).

use std::fmt::{self, Display, Formatter};
use std::str::FromStr;

use crate::expand::typetree::TypeTree;
use crate::expand::{Decodable, Encodable, HashStable_Generic};
use crate::ptr::P;
use crate::{Ty, TyKind};

/// The role a function plays in automatic differentiation (AD) and the AD mode used for it.
///
/// Forward and Reverse Mode are well-known names for automatic differentiation implementations.
/// Enzyme supports both, but with different semantics, see [`DiffActivity`].
///
/// The `*First` variants are a hack to support higher-order derivatives. We need to compute
/// first-order derivatives before we compute second-order derivatives, otherwise we would
/// differentiate our placeholder functions. The proper solution is to recognize and resolve this
/// DAG of autodiff invocations, as is already done in the C++ and Julia frontends of Enzyme.
///
/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
///
/// FIXME: remove the `*First` variants.
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum DiffMode {
    /// No autodiff is applied (used during error handling).
    Error,
    /// The primal function which we will differentiate.
    Source,
    /// The target function, to be created using forward mode AD.
    Forward,
    /// The target function, to be created using reverse mode AD.
    Reverse,
    /// The target function, to be created using forward mode AD.
    /// This target function will also be used as a source for higher-order derivatives,
    /// so compute it before all `Forward`/`Reverse` targets and optimize it through LLVM.
    ForwardFirst,
    /// The target function, to be created using reverse mode AD.
    /// This target function will also be used as a source for higher-order derivatives,
    /// so compute it before all `Forward`/`Reverse` targets and optimize it through LLVM.
    ReverseFirst,
}

/// How a single input or the return value takes part in differentiation.
///
/// `Dual` and `Duplicated` (and their `Only` variants) are lowered to the same Enzyme activity.
/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
/// we add to the previous shadow value. To not surprise users, we picked different names.
/// "Dual numbers" is also a quite well-known name for forward mode AD types.
#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum DiffActivity {
    /// Implicit or explicit `()` return type, so a special case of `Const`.
    None,
    /// Don't compute derivatives with respect to this input/output.
    Const,
    /// Reverse mode: compute derivatives for this scalar input/output.
    Active,
    /// Reverse mode: compute derivatives for this scalar output, but don't compute
    /// the original return value.
    ActiveOnly,
    /// Forward mode: compute derivatives for this input/output and *overwrite* the shadow
    /// argument with it.
    Dual,
    /// Forward mode: compute derivatives for this input/output and *overwrite* the shadow
    /// argument with it. Drop the code which updates the original input/output for maximum
    /// performance.
    DualOnly,
    /// Reverse mode: compute derivatives for this `&T` or `*T` input and *add* it to the shadow
    /// argument.
    Duplicated,
    /// Reverse mode: compute derivatives for this `&T` or `*T` input and *add* it to the shadow
    /// argument. Drop the code which updates the original input for maximum performance.
    DuplicatedOnly,
    /// All integers must be `Const`, but this variant is used to mark the integer which
    /// represents the length of a slice/vec. This is used for safety checks on slices.
    /// Note that the `FromStr` impl in this file has no arm for this variant, so it cannot be
    /// produced by parsing a string.
    FakeActivitySize,
}
/// We generate one of these structs for each `#[autodiff(...)]` attribute.
///
/// It pairs the names of the source and target functions with the differentiation settings
/// and the type trees describing the inputs and the output.
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffItem {
    /// The name of the function getting differentiated.
    pub source: String,
    /// The name of the function being generated.
    pub target: String,
    /// The differentiation mode together with the return and input activities.
    pub attrs: AutoDiffAttrs,
    /// Describes the memory layout of the input types.
    pub inputs: Vec<TypeTree>,
    /// Describes the memory layout of the output type.
    pub output: TypeTree,
}
/// The differentiation settings of one function: its mode plus the activity of the return
/// value and of the inputs.
#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct AutoDiffAttrs {
    /// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
    /// e.g. in the [JAX
    /// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
    pub mode: DiffMode,
    /// Activity of the return value; `DiffActivity::None` in the `error()` and `source()`
    /// constructors.
    pub ret_activity: DiffActivity,
    /// Activities of the inputs (presumably one entry per argument, in declaration order --
    /// verify against the code that fills this in).
    pub input_activity: Vec<DiffActivity>,
}

impl DiffMode {
    pub fn is_rev(&self) -> bool {
        matches!(self, DiffMode::Reverse | DiffMode::ReverseFirst)
    }
    pub fn is_fwd(&self) -> bool {
        matches!(self, DiffMode::Forward | DiffMode::ForwardFirst)
    }
}

impl Display for DiffMode {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        match self {
            DiffMode::Error => write!(f, "Error"),
            DiffMode::Source => write!(f, "Source"),
            DiffMode::Forward => write!(f, "Forward"),
            DiffMode::Reverse => write!(f, "Reverse"),
            DiffMode::ForwardFirst => write!(f, "ForwardFirst"),
            DiffMode::ReverseFirst => write!(f, "ReverseFirst"),
        }
    }
}

/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...).
/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...).
/// Const is valid for all cases and means that we don't compute derivatives wrt. this output.
/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg,
/// but this is too complex to verify here. Also it's just a logic error if users get this wrong.
pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
    if activity == DiffActivity::None {
        // Only valid if primal returns (), but we can't check that here.
        return true;
    }
    match mode {
        DiffMode::Error => false,
        DiffMode::Source => false,
        DiffMode::Forward | DiffMode::ForwardFirst => {
            activity == DiffActivity::Dual
                || activity == DiffActivity::DualOnly
                || activity == DiffActivity::Const
        }
        DiffMode::Reverse | DiffMode::ReverseFirst => {
            activity == DiffActivity::Const
                || activity == DiffActivity::Active
                || activity == DiffActivity::ActiveOnly
        }
    }
}

/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value
/// for the given argument, but we generally can't know the size of such a type.
/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated,
/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value
/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent
/// users here from marking scalars as Duplicated, due to type aliases.
pub fn valid_ty_for_activity(ty: &P<Ty>, activity: DiffActivity) -> bool {
    use DiffActivity::*;
    // It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
    if matches!(activity, Const) {
        return true;
    }
    if matches!(activity, Dual | DualOnly) {
        return true;
    }
    // FIXME(ZuseZ4) We should make this more robust to also
    // handle type aliases. Once that is done, we can be more restrictive here.
    if matches!(activity, Active | ActiveOnly) {
        return true;
    }
    matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..))
        && matches!(activity, Duplicated | DuplicatedOnly)
}
/// Checks whether `activity` is an acceptable activity for an input argument under `mode`.
///
/// * Forward mode (`Forward`, `ForwardFirst`) accepts `Dual`, `DualOnly` and `Const`.
/// * Reverse mode (`Reverse`, `ReverseFirst`) accepts `Active`, `ActiveOnly`, `Duplicated`,
///   `DuplicatedOnly` and `Const`.
/// * `Error` and `Source` accept no activity at all.
///
/// `None` and `FakeActivitySize` are therefore rejected for every mode.
pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
    use DiffActivity::*;
    match mode {
        DiffMode::Error | DiffMode::Source => false,
        DiffMode::Forward | DiffMode::ForwardFirst => {
            matches!(activity, Dual | DualOnly | Const)
        }
        DiffMode::Reverse | DiffMode::ReverseFirst => {
            matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
        }
    }
}

impl Display for DiffActivity {
    /// Writes the bare variant name (e.g. `DuplicatedOnly`) without any padding or decoration.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        let variant_name = match self {
            DiffActivity::None => "None",
            DiffActivity::Const => "Const",
            DiffActivity::Active => "Active",
            DiffActivity::ActiveOnly => "ActiveOnly",
            DiffActivity::Dual => "Dual",
            DiffActivity::DualOnly => "DualOnly",
            DiffActivity::Duplicated => "Duplicated",
            DiffActivity::DuplicatedOnly => "DuplicatedOnly",
            DiffActivity::FakeActivitySize => "FakeActivitySize",
        };
        f.write_str(variant_name)
    }
}

impl FromStr for DiffMode {
    type Err = ();

    fn from_str(s: &str) -> Result<DiffMode, ()> {
        match s {
            "Error" => Ok(DiffMode::Error),
            "Source" => Ok(DiffMode::Source),
            "Forward" => Ok(DiffMode::Forward),
            "Reverse" => Ok(DiffMode::Reverse),
            "ForwardFirst" => Ok(DiffMode::ForwardFirst),
            "ReverseFirst" => Ok(DiffMode::ReverseFirst),
            _ => Err(()),
        }
    }
}
impl FromStr for DiffActivity {
    type Err = ();

    /// Parses an exact, case-sensitive variant name; any other input yields `Err(())`.
    ///
    /// `"FakeActivitySize"` is not among the accepted names, so that variant cannot be
    /// obtained through parsing.
    fn from_str(s: &str) -> Result<DiffActivity, ()> {
        let activity = match s {
            "None" => DiffActivity::None,
            "Active" => DiffActivity::Active,
            "ActiveOnly" => DiffActivity::ActiveOnly,
            "Const" => DiffActivity::Const,
            "Dual" => DiffActivity::Dual,
            "DualOnly" => DiffActivity::DualOnly,
            "Duplicated" => DiffActivity::Duplicated,
            "DuplicatedOnly" => DiffActivity::DuplicatedOnly,
            _ => return Err(()),
        };
        Ok(activity)
    }
}

impl AutoDiffAttrs {
    pub fn has_ret_activity(&self) -> bool {
        self.ret_activity != DiffActivity::None
    }
    pub fn has_active_only_ret(&self) -> bool {
        self.ret_activity == DiffActivity::ActiveOnly
    }

    pub fn error() -> Self {
        AutoDiffAttrs {
            mode: DiffMode::Error,
            ret_activity: DiffActivity::None,
            input_activity: Vec::new(),
        }
    }
    pub fn source() -> Self {
        AutoDiffAttrs {
            mode: DiffMode::Source,
            ret_activity: DiffActivity::None,
            input_activity: Vec::new(),
        }
    }

    pub fn is_active(&self) -> bool {
        self.mode != DiffMode::Error
    }

    pub fn is_source(&self) -> bool {
        self.mode == DiffMode::Source
    }
    pub fn apply_autodiff(&self) -> bool {
        !matches!(self.mode, DiffMode::Error | DiffMode::Source)
    }

    pub fn into_item(
        self,
        source: String,
        target: String,
        inputs: Vec<TypeTree>,
        output: TypeTree,
    ) -> AutoDiffItem {
        AutoDiffItem { source, target, inputs, output, attrs: self }
    }
}

impl fmt::Display for AutoDiffItem {
    /// Writes a one-line summary: source and target names, followed by the `Debug`
    /// representations of the attributes, the inputs and the output.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { source, target, attrs, inputs, output } = self;
        write!(
            f,
            "Differentiating {} -> {} with attributes: {:?} with inputs: {:?} with output: {:?}",
            source, target, attrs, inputs, output
        )
    }
}