Skip to main content

rustc_ast/expand/
autodiff_attrs.rs

//! This module handles the user facing autodiff macro. For each `#[autodiff(...)]` attribute,
//! we create a `RustcAutodiff` which contains the source and target function names. The source
//! is the function to which the autodiff attribute is applied, and the target is the function
//! getting generated by us (with a name given by the user as the first autodiff arg).
5
6use std::fmt::{self, Display, Formatter};
7use std::str::FromStr;
8
9use rustc_span::{Symbol, sym};
10
11use crate::expand::{Decodable, Encodable, HashStable_Generic};
12use crate::{Ty, TyKind};
13
14/// Forward and Reverse Mode are well known names for automatic differentiation implementations.
15/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
16/// are a hack to support higher order derivatives. We need to compute first order derivatives
17/// before we compute second order derivatives, otherwise we would differentiate our placeholder
18/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
19/// as it's already done in the C++ and Julia frontend of Enzyme.
20///
21/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
22/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
23#[derive(#[automatically_derived]
impl ::core::clone::Clone for DiffMode {
    #[inline]
    fn clone(&self) -> DiffMode { *self }
}Clone, #[automatically_derived]
impl ::core::marker::Copy for DiffMode { }Copy, #[automatically_derived]
impl ::core::cmp::Eq for DiffMode {
    #[doc(hidden)]
    #[coverage(off)]
    fn assert_fields_are_eq(&self) {}
}Eq, #[automatically_derived]
impl ::core::cmp::PartialEq for DiffMode {
    #[inline]
    fn eq(&self, other: &DiffMode) -> bool {
        let __self_discr = ::core::intrinsics::discriminant_value(self);
        let __arg1_discr = ::core::intrinsics::discriminant_value(other);
        __self_discr == __arg1_discr
    }
}PartialEq, const _: () =
    {
        impl<__E: ::rustc_span::SpanEncoder> ::rustc_serialize::Encodable<__E>
            for DiffMode {
            fn encode(&self, __encoder: &mut __E) {
                let disc =
                    match *self {
                        DiffMode::Error => { 0usize }
                        DiffMode::Source => { 1usize }
                        DiffMode::Forward => { 2usize }
                        DiffMode::Reverse => { 3usize }
                    };
                ::rustc_serialize::Encoder::emit_u8(__encoder, disc as u8);
                match *self {
                    DiffMode::Error => {}
                    DiffMode::Source => {}
                    DiffMode::Forward => {}
                    DiffMode::Reverse => {}
                }
            }
        }
    };Encodable, const _: () =
    {
        impl<__D: ::rustc_span::SpanDecoder> ::rustc_serialize::Decodable<__D>
            for DiffMode {
            fn decode(__decoder: &mut __D) -> Self {
                match ::rustc_serialize::Decoder::read_u8(__decoder) as usize
                    {
                    0usize => { DiffMode::Error }
                    1usize => { DiffMode::Source }
                    2usize => { DiffMode::Forward }
                    3usize => { DiffMode::Reverse }
                    n => {
                        ::core::panicking::panic_fmt(format_args!("invalid enum variant tag while decoding `DiffMode`, expected 0..4, actual {0}",
                                n));
                    }
                }
            }
        }
    };Decodable, #[automatically_derived]
impl ::core::fmt::Debug for DiffMode {
    #[inline]
    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
        ::core::fmt::Formatter::write_str(f,
            match self {
                DiffMode::Error => "Error",
                DiffMode::Source => "Source",
                DiffMode::Forward => "Forward",
                DiffMode::Reverse => "Reverse",
            })
    }
}Debug, const _: () =
    {
        impl<__CTX> ::rustc_data_structures::stable_hasher::HashStable<__CTX>
            for DiffMode where __CTX: crate::HashStableContext {
            #[inline]
            fn hash_stable(&self, __hcx: &mut __CTX,
                __hasher:
                    &mut ::rustc_data_structures::stable_hasher::StableHasher) {
                ::std::mem::discriminant(self).hash_stable(__hcx, __hasher);
                match *self {
                    DiffMode::Error => {}
                    DiffMode::Source => {}
                    DiffMode::Forward => {}
                    DiffMode::Reverse => {}
                }
            }
        }
    };HashStable_Generic)]
24pub enum DiffMode {
25    /// No autodiff is applied (used during error handling).
26    Error,
27    /// The primal function which we will differentiate.
28    Source,
29    /// The target function, to be created using forward mode AD.
30    Forward,
31    /// The target function, to be created using reverse mode AD.
32    Reverse,
33}
34
35impl DiffMode {
36    pub fn all_modes() -> &'static [Symbol] {
37        &[sym::Source, sym::Forward, sym::Reverse]
38    }
39}
40
41/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
42/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
43/// we add to the previous shadow value. To not surprise users, we picked different names.
44/// Dual numbers is also a quite well known name for forward mode AD types.
45#[derive(#[automatically_derived]
impl ::core::clone::Clone for DiffActivity {
    #[inline]
    fn clone(&self) -> DiffActivity {
        let _: ::core::clone::AssertParamIsClone<Option<u32>>;
        *self
    }
}Clone, #[automatically_derived]
impl ::core::marker::Copy for DiffActivity { }Copy, #[automatically_derived]
impl ::core::cmp::Eq for DiffActivity {
    #[doc(hidden)]
    #[coverage(off)]
    fn assert_fields_are_eq(&self) {
        let _: ::core::cmp::AssertParamIsEq<Option<u32>>;
    }
}Eq, #[automatically_derived]
impl ::core::cmp::PartialEq for DiffActivity {
    #[inline]
    fn eq(&self, other: &DiffActivity) -> bool {
        let __self_discr = ::core::intrinsics::discriminant_value(self);
        let __arg1_discr = ::core::intrinsics::discriminant_value(other);
        __self_discr == __arg1_discr &&
            match (self, other) {
                (DiffActivity::FakeActivitySize(__self_0),
                    DiffActivity::FakeActivitySize(__arg1_0)) =>
                    __self_0 == __arg1_0,
                _ => true,
            }
    }
}PartialEq, const _: () =
    {
        impl<__E: ::rustc_span::SpanEncoder> ::rustc_serialize::Encodable<__E>
            for DiffActivity {
            fn encode(&self, __encoder: &mut __E) {
                let disc =
                    match *self {
                        DiffActivity::None => { 0usize }
                        DiffActivity::Const => { 1usize }
                        DiffActivity::Active => { 2usize }
                        DiffActivity::ActiveOnly => { 3usize }
                        DiffActivity::Dual => { 4usize }
                        DiffActivity::Dualv => { 5usize }
                        DiffActivity::DualOnly => { 6usize }
                        DiffActivity::DualvOnly => { 7usize }
                        DiffActivity::Duplicated => { 8usize }
                        DiffActivity::DuplicatedOnly => { 9usize }
                        DiffActivity::FakeActivitySize(ref __binding_0) => {
                            10usize
                        }
                    };
                ::rustc_serialize::Encoder::emit_u8(__encoder, disc as u8);
                match *self {
                    DiffActivity::None => {}
                    DiffActivity::Const => {}
                    DiffActivity::Active => {}
                    DiffActivity::ActiveOnly => {}
                    DiffActivity::Dual => {}
                    DiffActivity::Dualv => {}
                    DiffActivity::DualOnly => {}
                    DiffActivity::DualvOnly => {}
                    DiffActivity::Duplicated => {}
                    DiffActivity::DuplicatedOnly => {}
                    DiffActivity::FakeActivitySize(ref __binding_0) => {
                        ::rustc_serialize::Encodable::<__E>::encode(__binding_0,
                            __encoder);
                    }
                }
            }
        }
    };Encodable, const _: () =
    {
        impl<__D: ::rustc_span::SpanDecoder> ::rustc_serialize::Decodable<__D>
            for DiffActivity {
            fn decode(__decoder: &mut __D) -> Self {
                match ::rustc_serialize::Decoder::read_u8(__decoder) as usize
                    {
                    0usize => { DiffActivity::None }
                    1usize => { DiffActivity::Const }
                    2usize => { DiffActivity::Active }
                    3usize => { DiffActivity::ActiveOnly }
                    4usize => { DiffActivity::Dual }
                    5usize => { DiffActivity::Dualv }
                    6usize => { DiffActivity::DualOnly }
                    7usize => { DiffActivity::DualvOnly }
                    8usize => { DiffActivity::Duplicated }
                    9usize => { DiffActivity::DuplicatedOnly }
                    10usize => {
                        DiffActivity::FakeActivitySize(::rustc_serialize::Decodable::decode(__decoder))
                    }
                    n => {
                        ::core::panicking::panic_fmt(format_args!("invalid enum variant tag while decoding `DiffActivity`, expected 0..11, actual {0}",
                                n));
                    }
                }
            }
        }
    };Decodable, #[automatically_derived]
impl ::core::fmt::Debug for DiffActivity {
    #[inline]
    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
        match self {
            DiffActivity::None =>
                ::core::fmt::Formatter::write_str(f, "None"),
            DiffActivity::Const =>
                ::core::fmt::Formatter::write_str(f, "Const"),
            DiffActivity::Active =>
                ::core::fmt::Formatter::write_str(f, "Active"),
            DiffActivity::ActiveOnly =>
                ::core::fmt::Formatter::write_str(f, "ActiveOnly"),
            DiffActivity::Dual =>
                ::core::fmt::Formatter::write_str(f, "Dual"),
            DiffActivity::Dualv =>
                ::core::fmt::Formatter::write_str(f, "Dualv"),
            DiffActivity::DualOnly =>
                ::core::fmt::Formatter::write_str(f, "DualOnly"),
            DiffActivity::DualvOnly =>
                ::core::fmt::Formatter::write_str(f, "DualvOnly"),
            DiffActivity::Duplicated =>
                ::core::fmt::Formatter::write_str(f, "Duplicated"),
            DiffActivity::DuplicatedOnly =>
                ::core::fmt::Formatter::write_str(f, "DuplicatedOnly"),
            DiffActivity::FakeActivitySize(__self_0) =>
                ::core::fmt::Formatter::debug_tuple_field1_finish(f,
                    "FakeActivitySize", &__self_0),
        }
    }
}Debug, const _: () =
    {
        impl<__CTX> ::rustc_data_structures::stable_hasher::HashStable<__CTX>
            for DiffActivity where __CTX: crate::HashStableContext {
            #[inline]
            fn hash_stable(&self, __hcx: &mut __CTX,
                __hasher:
                    &mut ::rustc_data_structures::stable_hasher::StableHasher) {
                ::std::mem::discriminant(self).hash_stable(__hcx, __hasher);
                match *self {
                    DiffActivity::None => {}
                    DiffActivity::Const => {}
                    DiffActivity::Active => {}
                    DiffActivity::ActiveOnly => {}
                    DiffActivity::Dual => {}
                    DiffActivity::Dualv => {}
                    DiffActivity::DualOnly => {}
                    DiffActivity::DualvOnly => {}
                    DiffActivity::Duplicated => {}
                    DiffActivity::DuplicatedOnly => {}
                    DiffActivity::FakeActivitySize(ref __binding_0) => {
                        { __binding_0.hash_stable(__hcx, __hasher); }
                    }
                }
            }
        }
    };HashStable_Generic)]
46pub enum DiffActivity {
47    /// Implicit or Explicit () return type, so a special case of Const.
48    None,
49    /// Don't compute derivatives with respect to this input/output.
50    Const,
51    /// Reverse Mode, Compute derivatives for this scalar input/output.
52    Active,
53    /// Reverse Mode, Compute derivatives for this scalar output, but don't compute
54    /// the original return value.
55    ActiveOnly,
56    /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
57    /// with it.
58    Dual,
59    /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
60    /// with it. It expects the shadow argument to be `width` times larger than the original
61    /// input/output.
62    Dualv,
63    /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
64    /// with it. Drop the code which updates the original input/output for maximum performance.
65    DualOnly,
66    /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
67    /// with it. Drop the code which updates the original input/output for maximum performance.
68    /// It expects the shadow argument to be `width` times larger than the original input/output.
69    DualvOnly,
70    /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
71    Duplicated,
72    /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
73    /// Drop the code which updates the original input for maximum performance.
74    DuplicatedOnly,
75    /// All Integers must be Const, but these are used to mark the integer which represents the
76    /// length of a slice/vec. This is used for safety checks on slices.
77    /// The integer (if given) specifies the size of the slice element in bytes.
78    FakeActivitySize(Option<u32>),
79}
80
81impl DiffActivity {
82    pub fn is_dual_or_const(&self) -> bool {
83        use DiffActivity::*;
84        #[allow(non_exhaustive_omitted_patterns)] match self {
    Dual | DualOnly | Dualv | DualvOnly | Const => true,
    _ => false,
}matches!(self, |Dual| DualOnly | Dualv | DualvOnly | Const)
85    }
86
87    pub fn all_activities() -> &'static [Symbol] {
88        &[
89            sym::None,
90            sym::Active,
91            sym::ActiveOnly,
92            sym::Const,
93            sym::Dual,
94            sym::Dualv,
95            sym::DualOnly,
96            sym::DualvOnly,
97            sym::Duplicated,
98            sym::DuplicatedOnly,
99        ]
100    }
101}
102
103impl DiffMode {
104    pub fn is_rev(&self) -> bool {
105        #[allow(non_exhaustive_omitted_patterns)] match self {
    DiffMode::Reverse => true,
    _ => false,
}matches!(self, DiffMode::Reverse)
106    }
107    pub fn is_fwd(&self) -> bool {
108        #[allow(non_exhaustive_omitted_patterns)] match self {
    DiffMode::Forward => true,
    _ => false,
}matches!(self, DiffMode::Forward)
109    }
110}
111
112impl Display for DiffMode {
113    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
114        match self {
115            DiffMode::Error => f.write_fmt(format_args!("Error"))write!(f, "Error"),
116            DiffMode::Source => f.write_fmt(format_args!("Source"))write!(f, "Source"),
117            DiffMode::Forward => f.write_fmt(format_args!("Forward"))write!(f, "Forward"),
118            DiffMode::Reverse => f.write_fmt(format_args!("Reverse"))write!(f, "Reverse"),
119        }
120    }
121}
122
123/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...).
124/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...).
125/// Const is valid for all cases and means that we don't compute derivatives wrt. this output.
126/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg,
127/// but this is too complex to verify here. Also it's just a logic error if users get this wrong.
128pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
129    if activity == DiffActivity::None {
130        // Only valid if primal returns (), but we can't check that here.
131        return true;
132    }
133    match mode {
134        DiffMode::Error => false,
135        DiffMode::Source => false,
136        DiffMode::Forward => activity.is_dual_or_const(),
137        DiffMode::Reverse => {
138            activity == DiffActivity::Const
139                || activity == DiffActivity::Active
140                || activity == DiffActivity::ActiveOnly
141        }
142    }
143}
144
145/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value
146/// for the given argument, but we generally can't know the size of such a type.
147/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated,
148/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value
149/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent
150/// users here from marking scalars as Duplicated, due to type aliases.
151pub fn valid_ty_for_activity(ty: &Box<Ty>, activity: DiffActivity) -> bool {
152    use DiffActivity::*;
153    // It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
154    // Dual variants also support all types.
155    if activity.is_dual_or_const() {
156        return true;
157    }
158    // FIXME(ZuseZ4) We should make this more robust to also
159    // handle type aliases. Once that is done, we can be more restrictive here.
160    if #[allow(non_exhaustive_omitted_patterns)] match activity {
    Active | ActiveOnly => true,
    _ => false,
}matches!(activity, Active | ActiveOnly) {
161        return true;
162    }
163    #[allow(non_exhaustive_omitted_patterns)] match ty.kind {
    TyKind::Ptr(_) | TyKind::Ref(..) => true,
    _ => false,
}matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..))
164        && #[allow(non_exhaustive_omitted_patterns)] match activity {
    Duplicated | DuplicatedOnly => true,
    _ => false,
}matches!(activity, Duplicated | DuplicatedOnly)
165}
166pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
167    use DiffActivity::*;
168    return match mode {
169        DiffMode::Error => false,
170        DiffMode::Source => false,
171        DiffMode::Forward => activity.is_dual_or_const(),
172        DiffMode::Reverse => {
173            #[allow(non_exhaustive_omitted_patterns)] match activity {
    Active | ActiveOnly | Duplicated | DuplicatedOnly | Const => true,
    _ => false,
}matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
174        }
175    };
176}
177
178impl Display for DiffActivity {
179    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
180        match self {
181            DiffActivity::None => f.write_fmt(format_args!("None"))write!(f, "None"),
182            DiffActivity::Const => f.write_fmt(format_args!("Const"))write!(f, "Const"),
183            DiffActivity::Active => f.write_fmt(format_args!("Active"))write!(f, "Active"),
184            DiffActivity::ActiveOnly => f.write_fmt(format_args!("ActiveOnly"))write!(f, "ActiveOnly"),
185            DiffActivity::Dual => f.write_fmt(format_args!("Dual"))write!(f, "Dual"),
186            DiffActivity::Dualv => f.write_fmt(format_args!("Dualv"))write!(f, "Dualv"),
187            DiffActivity::DualOnly => f.write_fmt(format_args!("DualOnly"))write!(f, "DualOnly"),
188            DiffActivity::DualvOnly => f.write_fmt(format_args!("DualvOnly"))write!(f, "DualvOnly"),
189            DiffActivity::Duplicated => f.write_fmt(format_args!("Duplicated"))write!(f, "Duplicated"),
190            DiffActivity::DuplicatedOnly => f.write_fmt(format_args!("DuplicatedOnly"))write!(f, "DuplicatedOnly"),
191            DiffActivity::FakeActivitySize(s) => f.write_fmt(format_args!("FakeActivitySize({0:?})", s))write!(f, "FakeActivitySize({:?})", s),
192        }
193    }
194}
195
196impl FromStr for DiffMode {
197    type Err = ();
198
199    fn from_str(s: &str) -> Result<DiffMode, ()> {
200        match s {
201            "Error" => Ok(DiffMode::Error),
202            "Source" => Ok(DiffMode::Source),
203            "Forward" => Ok(DiffMode::Forward),
204            "Reverse" => Ok(DiffMode::Reverse),
205            _ => Err(()),
206        }
207    }
208}
209impl FromStr for DiffActivity {
210    type Err = ();
211
212    fn from_str(s: &str) -> Result<DiffActivity, ()> {
213        match s {
214            "None" => Ok(DiffActivity::None),
215            "Active" => Ok(DiffActivity::Active),
216            "ActiveOnly" => Ok(DiffActivity::ActiveOnly),
217            "Const" => Ok(DiffActivity::Const),
218            "Dual" => Ok(DiffActivity::Dual),
219            "Dualv" => Ok(DiffActivity::Dualv),
220            "DualOnly" => Ok(DiffActivity::DualOnly),
221            "DualvOnly" => Ok(DiffActivity::DualvOnly),
222            "Duplicated" => Ok(DiffActivity::Duplicated),
223            "DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
224            _ => Err(()),
225        }
226    }
227}