Skip to main content

rustc_ast/expand/
autodiff_attrs.rs

//! This module contains the data types and validation helpers behind the user-facing autodiff
//! macro. For each `#[autodiff(...)]` attribute, we create an [`AutoDiffItem`] which contains the
//! source and target function names. The source is the function to which the autodiff attribute
//! is applied, and the target is the function we generate (with a name given by the user as the
//! first autodiff argument).
5
6use std::fmt::{self, Display, Formatter};
7use std::str::FromStr;
8
9use crate::expand::typetree::TypeTree;
10use crate::expand::{Decodable, Encodable, HashStable_Generic};
11use crate::{Ty, TyKind};
12
13/// Forward and Reverse Mode are well known names for automatic differentiation implementations.
14/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
15/// are a hack to support higher order derivatives. We need to compute first order derivatives
16/// before we compute second order derivatives, otherwise we would differentiate our placeholder
17/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
18/// as it's already done in the C++ and Julia frontend of Enzyme.
19///
20/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
21/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
22#[derive(#[automatically_derived]
impl ::core::clone::Clone for DiffMode {
    #[inline]
    fn clone(&self) -> DiffMode { *self }
}Clone, #[automatically_derived]
impl ::core::marker::Copy for DiffMode { }Copy, #[automatically_derived]
impl ::core::cmp::Eq for DiffMode {
    #[inline]
    #[doc(hidden)]
    #[coverage(off)]
    fn assert_receiver_is_total_eq(&self) {}
}Eq, #[automatically_derived]
impl ::core::cmp::PartialEq for DiffMode {
    #[inline]
    fn eq(&self, other: &DiffMode) -> bool {
        let __self_discr = ::core::intrinsics::discriminant_value(self);
        let __arg1_discr = ::core::intrinsics::discriminant_value(other);
        __self_discr == __arg1_discr
    }
}PartialEq, const _: () =
    {
        impl<__E: ::rustc_span::SpanEncoder> ::rustc_serialize::Encodable<__E>
            for DiffMode {
            fn encode(&self, __encoder: &mut __E) {
                let disc =
                    match *self {
                        DiffMode::Error => { 0usize }
                        DiffMode::Source => { 1usize }
                        DiffMode::Forward => { 2usize }
                        DiffMode::Reverse => { 3usize }
                    };
                ::rustc_serialize::Encoder::emit_u8(__encoder, disc as u8);
                match *self {
                    DiffMode::Error => {}
                    DiffMode::Source => {}
                    DiffMode::Forward => {}
                    DiffMode::Reverse => {}
                }
            }
        }
    };Encodable, const _: () =
    {
        impl<__D: ::rustc_span::SpanDecoder> ::rustc_serialize::Decodable<__D>
            for DiffMode {
            fn decode(__decoder: &mut __D) -> Self {
                match ::rustc_serialize::Decoder::read_u8(__decoder) as usize
                    {
                    0usize => { DiffMode::Error }
                    1usize => { DiffMode::Source }
                    2usize => { DiffMode::Forward }
                    3usize => { DiffMode::Reverse }
                    n => {
                        ::core::panicking::panic_fmt(format_args!("invalid enum variant tag while decoding `DiffMode`, expected 0..4, actual {0}",
                                n));
                    }
                }
            }
        }
    };Decodable, #[automatically_derived]
impl ::core::fmt::Debug for DiffMode {
    #[inline]
    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
        ::core::fmt::Formatter::write_str(f,
            match self {
                DiffMode::Error => "Error",
                DiffMode::Source => "Source",
                DiffMode::Forward => "Forward",
                DiffMode::Reverse => "Reverse",
            })
    }
}Debug, const _: () =
    {
        impl<__CTX> ::rustc_data_structures::stable_hasher::HashStable<__CTX>
            for DiffMode where __CTX: crate::HashStableContext {
            #[inline]
            fn hash_stable(&self, __hcx: &mut __CTX,
                __hasher:
                    &mut ::rustc_data_structures::stable_hasher::StableHasher) {
                ::std::mem::discriminant(self).hash_stable(__hcx, __hasher);
                match *self {
                    DiffMode::Error => {}
                    DiffMode::Source => {}
                    DiffMode::Forward => {}
                    DiffMode::Reverse => {}
                }
            }
        }
    };HashStable_Generic)]
23pub enum DiffMode {
24    /// No autodiff is applied (used during error handling).
25    Error,
26    /// The primal function which we will differentiate.
27    Source,
28    /// The target function, to be created using forward mode AD.
29    Forward,
30    /// The target function, to be created using reverse mode AD.
31    Reverse,
32}
33
34/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
35/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
36/// we add to the previous shadow value. To not surprise users, we picked different names.
37/// Dual numbers is also a quite well known name for forward mode AD types.
38#[derive(#[automatically_derived]
impl ::core::clone::Clone for DiffActivity {
    #[inline]
    fn clone(&self) -> DiffActivity {
        let _: ::core::clone::AssertParamIsClone<Option<u32>>;
        *self
    }
}Clone, #[automatically_derived]
impl ::core::marker::Copy for DiffActivity { }Copy, #[automatically_derived]
impl ::core::cmp::Eq for DiffActivity {
    #[inline]
    #[doc(hidden)]
    #[coverage(off)]
    fn assert_receiver_is_total_eq(&self) {
        let _: ::core::cmp::AssertParamIsEq<Option<u32>>;
    }
}Eq, #[automatically_derived]
impl ::core::cmp::PartialEq for DiffActivity {
    #[inline]
    fn eq(&self, other: &DiffActivity) -> bool {
        let __self_discr = ::core::intrinsics::discriminant_value(self);
        let __arg1_discr = ::core::intrinsics::discriminant_value(other);
        __self_discr == __arg1_discr &&
            match (self, other) {
                (DiffActivity::FakeActivitySize(__self_0),
                    DiffActivity::FakeActivitySize(__arg1_0)) =>
                    __self_0 == __arg1_0,
                _ => true,
            }
    }
}PartialEq, const _: () =
    {
        impl<__E: ::rustc_span::SpanEncoder> ::rustc_serialize::Encodable<__E>
            for DiffActivity {
            fn encode(&self, __encoder: &mut __E) {
                let disc =
                    match *self {
                        DiffActivity::None => { 0usize }
                        DiffActivity::Const => { 1usize }
                        DiffActivity::Active => { 2usize }
                        DiffActivity::ActiveOnly => { 3usize }
                        DiffActivity::Dual => { 4usize }
                        DiffActivity::Dualv => { 5usize }
                        DiffActivity::DualOnly => { 6usize }
                        DiffActivity::DualvOnly => { 7usize }
                        DiffActivity::Duplicated => { 8usize }
                        DiffActivity::DuplicatedOnly => { 9usize }
                        DiffActivity::FakeActivitySize(ref __binding_0) => {
                            10usize
                        }
                    };
                ::rustc_serialize::Encoder::emit_u8(__encoder, disc as u8);
                match *self {
                    DiffActivity::None => {}
                    DiffActivity::Const => {}
                    DiffActivity::Active => {}
                    DiffActivity::ActiveOnly => {}
                    DiffActivity::Dual => {}
                    DiffActivity::Dualv => {}
                    DiffActivity::DualOnly => {}
                    DiffActivity::DualvOnly => {}
                    DiffActivity::Duplicated => {}
                    DiffActivity::DuplicatedOnly => {}
                    DiffActivity::FakeActivitySize(ref __binding_0) => {
                        ::rustc_serialize::Encodable::<__E>::encode(__binding_0,
                            __encoder);
                    }
                }
            }
        }
    };Encodable, const _: () =
    {
        impl<__D: ::rustc_span::SpanDecoder> ::rustc_serialize::Decodable<__D>
            for DiffActivity {
            fn decode(__decoder: &mut __D) -> Self {
                match ::rustc_serialize::Decoder::read_u8(__decoder) as usize
                    {
                    0usize => { DiffActivity::None }
                    1usize => { DiffActivity::Const }
                    2usize => { DiffActivity::Active }
                    3usize => { DiffActivity::ActiveOnly }
                    4usize => { DiffActivity::Dual }
                    5usize => { DiffActivity::Dualv }
                    6usize => { DiffActivity::DualOnly }
                    7usize => { DiffActivity::DualvOnly }
                    8usize => { DiffActivity::Duplicated }
                    9usize => { DiffActivity::DuplicatedOnly }
                    10usize => {
                        DiffActivity::FakeActivitySize(::rustc_serialize::Decodable::decode(__decoder))
                    }
                    n => {
                        ::core::panicking::panic_fmt(format_args!("invalid enum variant tag while decoding `DiffActivity`, expected 0..11, actual {0}",
                                n));
                    }
                }
            }
        }
    };Decodable, #[automatically_derived]
impl ::core::fmt::Debug for DiffActivity {
    #[inline]
    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
        match self {
            DiffActivity::None =>
                ::core::fmt::Formatter::write_str(f, "None"),
            DiffActivity::Const =>
                ::core::fmt::Formatter::write_str(f, "Const"),
            DiffActivity::Active =>
                ::core::fmt::Formatter::write_str(f, "Active"),
            DiffActivity::ActiveOnly =>
                ::core::fmt::Formatter::write_str(f, "ActiveOnly"),
            DiffActivity::Dual =>
                ::core::fmt::Formatter::write_str(f, "Dual"),
            DiffActivity::Dualv =>
                ::core::fmt::Formatter::write_str(f, "Dualv"),
            DiffActivity::DualOnly =>
                ::core::fmt::Formatter::write_str(f, "DualOnly"),
            DiffActivity::DualvOnly =>
                ::core::fmt::Formatter::write_str(f, "DualvOnly"),
            DiffActivity::Duplicated =>
                ::core::fmt::Formatter::write_str(f, "Duplicated"),
            DiffActivity::DuplicatedOnly =>
                ::core::fmt::Formatter::write_str(f, "DuplicatedOnly"),
            DiffActivity::FakeActivitySize(__self_0) =>
                ::core::fmt::Formatter::debug_tuple_field1_finish(f,
                    "FakeActivitySize", &__self_0),
        }
    }
}Debug, const _: () =
    {
        impl<__CTX> ::rustc_data_structures::stable_hasher::HashStable<__CTX>
            for DiffActivity where __CTX: crate::HashStableContext {
            #[inline]
            fn hash_stable(&self, __hcx: &mut __CTX,
                __hasher:
                    &mut ::rustc_data_structures::stable_hasher::StableHasher) {
                ::std::mem::discriminant(self).hash_stable(__hcx, __hasher);
                match *self {
                    DiffActivity::None => {}
                    DiffActivity::Const => {}
                    DiffActivity::Active => {}
                    DiffActivity::ActiveOnly => {}
                    DiffActivity::Dual => {}
                    DiffActivity::Dualv => {}
                    DiffActivity::DualOnly => {}
                    DiffActivity::DualvOnly => {}
                    DiffActivity::Duplicated => {}
                    DiffActivity::DuplicatedOnly => {}
                    DiffActivity::FakeActivitySize(ref __binding_0) => {
                        { __binding_0.hash_stable(__hcx, __hasher); }
                    }
                }
            }
        }
    };HashStable_Generic)]
39pub enum DiffActivity {
40    /// Implicit or Explicit () return type, so a special case of Const.
41    None,
42    /// Don't compute derivatives with respect to this input/output.
43    Const,
44    /// Reverse Mode, Compute derivatives for this scalar input/output.
45    Active,
46    /// Reverse Mode, Compute derivatives for this scalar output, but don't compute
47    /// the original return value.
48    ActiveOnly,
49    /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
50    /// with it.
51    Dual,
52    /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
53    /// with it. It expects the shadow argument to be `width` times larger than the original
54    /// input/output.
55    Dualv,
56    /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
57    /// with it. Drop the code which updates the original input/output for maximum performance.
58    DualOnly,
59    /// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
60    /// with it. Drop the code which updates the original input/output for maximum performance.
61    /// It expects the shadow argument to be `width` times larger than the original input/output.
62    DualvOnly,
63    /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
64    Duplicated,
65    /// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
66    /// Drop the code which updates the original input for maximum performance.
67    DuplicatedOnly,
68    /// All Integers must be Const, but these are used to mark the integer which represents the
69    /// length of a slice/vec. This is used for safety checks on slices.
70    /// The integer (if given) specifies the size of the slice element in bytes.
71    FakeActivitySize(Option<u32>),
72}
73
74impl DiffActivity {
75    pub fn is_dual_or_const(&self) -> bool {
76        use DiffActivity::*;
77        #[allow(non_exhaustive_omitted_patterns)] match self {
    Dual | DualOnly | Dualv | DualvOnly | Const => true,
    _ => false,
}matches!(self, |Dual| DualOnly | Dualv | DualvOnly | Const)
78    }
79}
80/// We generate one of these structs for each `#[autodiff(...)]` attribute.
81#[derive(#[automatically_derived]
impl ::core::clone::Clone for AutoDiffItem {
    #[inline]
    fn clone(&self) -> AutoDiffItem {
        AutoDiffItem {
            source: ::core::clone::Clone::clone(&self.source),
            target: ::core::clone::Clone::clone(&self.target),
            attrs: ::core::clone::Clone::clone(&self.attrs),
            inputs: ::core::clone::Clone::clone(&self.inputs),
            output: ::core::clone::Clone::clone(&self.output),
        }
    }
}Clone, #[automatically_derived]
impl ::core::cmp::Eq for AutoDiffItem {
    #[inline]
    #[doc(hidden)]
    #[coverage(off)]
    fn assert_receiver_is_total_eq(&self) {
        let _: ::core::cmp::AssertParamIsEq<String>;
        let _: ::core::cmp::AssertParamIsEq<AutoDiffAttrs>;
        let _: ::core::cmp::AssertParamIsEq<Vec<TypeTree>>;
        let _: ::core::cmp::AssertParamIsEq<TypeTree>;
    }
}Eq, #[automatically_derived]
impl ::core::cmp::PartialEq for AutoDiffItem {
    #[inline]
    fn eq(&self, other: &AutoDiffItem) -> bool {
        self.source == other.source && self.target == other.target &&
                    self.attrs == other.attrs && self.inputs == other.inputs &&
            self.output == other.output
    }
}PartialEq, const _: () =
    {
        impl<__E: ::rustc_span::SpanEncoder> ::rustc_serialize::Encodable<__E>
            for AutoDiffItem {
            fn encode(&self, __encoder: &mut __E) {
                match *self {
                    AutoDiffItem {
                        source: ref __binding_0,
                        target: ref __binding_1,
                        attrs: ref __binding_2,
                        inputs: ref __binding_3,
                        output: ref __binding_4 } => {
                        ::rustc_serialize::Encodable::<__E>::encode(__binding_0,
                            __encoder);
                        ::rustc_serialize::Encodable::<__E>::encode(__binding_1,
                            __encoder);
                        ::rustc_serialize::Encodable::<__E>::encode(__binding_2,
                            __encoder);
                        ::rustc_serialize::Encodable::<__E>::encode(__binding_3,
                            __encoder);
                        ::rustc_serialize::Encodable::<__E>::encode(__binding_4,
                            __encoder);
                    }
                }
            }
        }
    };Encodable, const _: () =
    {
        impl<__D: ::rustc_span::SpanDecoder> ::rustc_serialize::Decodable<__D>
            for AutoDiffItem {
            fn decode(__decoder: &mut __D) -> Self {
                AutoDiffItem {
                    source: ::rustc_serialize::Decodable::decode(__decoder),
                    target: ::rustc_serialize::Decodable::decode(__decoder),
                    attrs: ::rustc_serialize::Decodable::decode(__decoder),
                    inputs: ::rustc_serialize::Decodable::decode(__decoder),
                    output: ::rustc_serialize::Decodable::decode(__decoder),
                }
            }
        }
    };Decodable, #[automatically_derived]
impl ::core::fmt::Debug for AutoDiffItem {
    #[inline]
    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
        ::core::fmt::Formatter::debug_struct_field5_finish(f, "AutoDiffItem",
            "source", &self.source, "target", &self.target, "attrs",
            &self.attrs, "inputs", &self.inputs, "output", &&self.output)
    }
}Debug, const _: () =
    {
        impl<__CTX> ::rustc_data_structures::stable_hasher::HashStable<__CTX>
            for AutoDiffItem where __CTX: crate::HashStableContext {
            #[inline]
            fn hash_stable(&self, __hcx: &mut __CTX,
                __hasher:
                    &mut ::rustc_data_structures::stable_hasher::StableHasher) {
                match *self {
                    AutoDiffItem {
                        source: ref __binding_0,
                        target: ref __binding_1,
                        attrs: ref __binding_2,
                        inputs: ref __binding_3,
                        output: ref __binding_4 } => {
                        { __binding_0.hash_stable(__hcx, __hasher); }
                        { __binding_1.hash_stable(__hcx, __hasher); }
                        { __binding_2.hash_stable(__hcx, __hasher); }
                        { __binding_3.hash_stable(__hcx, __hasher); }
                        { __binding_4.hash_stable(__hcx, __hasher); }
                    }
                }
            }
        }
    };HashStable_Generic)]
82pub struct AutoDiffItem {
83    /// The name of the function getting differentiated
84    pub source: String,
85    /// The name of the function being generated
86    pub target: String,
87    pub attrs: AutoDiffAttrs,
88    pub inputs: Vec<TypeTree>,
89    pub output: TypeTree,
90}
91
92#[derive(#[automatically_derived]
impl ::core::clone::Clone for AutoDiffAttrs {
    #[inline]
    fn clone(&self) -> AutoDiffAttrs {
        AutoDiffAttrs {
            mode: ::core::clone::Clone::clone(&self.mode),
            width: ::core::clone::Clone::clone(&self.width),
            ret_activity: ::core::clone::Clone::clone(&self.ret_activity),
            input_activity: ::core::clone::Clone::clone(&self.input_activity),
        }
    }
}Clone, #[automatically_derived]
impl ::core::cmp::Eq for AutoDiffAttrs {
    #[inline]
    #[doc(hidden)]
    #[coverage(off)]
    fn assert_receiver_is_total_eq(&self) {
        let _: ::core::cmp::AssertParamIsEq<DiffMode>;
        let _: ::core::cmp::AssertParamIsEq<u32>;
        let _: ::core::cmp::AssertParamIsEq<DiffActivity>;
        let _: ::core::cmp::AssertParamIsEq<Vec<DiffActivity>>;
    }
}Eq, #[automatically_derived]
impl ::core::cmp::PartialEq for AutoDiffAttrs {
    #[inline]
    fn eq(&self, other: &AutoDiffAttrs) -> bool {
        self.width == other.width && self.mode == other.mode &&
                self.ret_activity == other.ret_activity &&
            self.input_activity == other.input_activity
    }
}PartialEq, const _: () =
    {
        impl<__E: ::rustc_span::SpanEncoder> ::rustc_serialize::Encodable<__E>
            for AutoDiffAttrs {
            fn encode(&self, __encoder: &mut __E) {
                match *self {
                    AutoDiffAttrs {
                        mode: ref __binding_0,
                        width: ref __binding_1,
                        ret_activity: ref __binding_2,
                        input_activity: ref __binding_3 } => {
                        ::rustc_serialize::Encodable::<__E>::encode(__binding_0,
                            __encoder);
                        ::rustc_serialize::Encodable::<__E>::encode(__binding_1,
                            __encoder);
                        ::rustc_serialize::Encodable::<__E>::encode(__binding_2,
                            __encoder);
                        ::rustc_serialize::Encodable::<__E>::encode(__binding_3,
                            __encoder);
                    }
                }
            }
        }
    };Encodable, const _: () =
    {
        impl<__D: ::rustc_span::SpanDecoder> ::rustc_serialize::Decodable<__D>
            for AutoDiffAttrs {
            fn decode(__decoder: &mut __D) -> Self {
                AutoDiffAttrs {
                    mode: ::rustc_serialize::Decodable::decode(__decoder),
                    width: ::rustc_serialize::Decodable::decode(__decoder),
                    ret_activity: ::rustc_serialize::Decodable::decode(__decoder),
                    input_activity: ::rustc_serialize::Decodable::decode(__decoder),
                }
            }
        }
    };Decodable, #[automatically_derived]
impl ::core::fmt::Debug for AutoDiffAttrs {
    #[inline]
    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
        ::core::fmt::Formatter::debug_struct_field4_finish(f, "AutoDiffAttrs",
            "mode", &self.mode, "width", &self.width, "ret_activity",
            &self.ret_activity, "input_activity", &&self.input_activity)
    }
}Debug, const _: () =
    {
        impl<__CTX> ::rustc_data_structures::stable_hasher::HashStable<__CTX>
            for AutoDiffAttrs where __CTX: crate::HashStableContext {
            #[inline]
            fn hash_stable(&self, __hcx: &mut __CTX,
                __hasher:
                    &mut ::rustc_data_structures::stable_hasher::StableHasher) {
                match *self {
                    AutoDiffAttrs {
                        mode: ref __binding_0,
                        width: ref __binding_1,
                        ret_activity: ref __binding_2,
                        input_activity: ref __binding_3 } => {
                        { __binding_0.hash_stable(__hcx, __hasher); }
                        { __binding_1.hash_stable(__hcx, __hasher); }
                        { __binding_2.hash_stable(__hcx, __hasher); }
                        { __binding_3.hash_stable(__hcx, __hasher); }
                    }
                }
            }
        }
    };HashStable_Generic)]
93pub struct AutoDiffAttrs {
94    /// Conceptually either forward or reverse mode AD, as described in various autodiff papers and
95    /// e.g. in the [JAX
96    /// Documentation](https://jax.readthedocs.io/en/latest/_tutorials/advanced-autodiff.html#how-it-s-made-two-foundational-autodiff-functions).
97    pub mode: DiffMode,
98    /// A user-provided, batching width. If not given, we will default to 1 (no batching).
99    /// Calling a differentiated, non-batched function through a loop 100 times is equivalent to:
100    /// - Calling the function 50 times with a batch size of 2
101    /// - Calling the function 25 times with a batch size of 4,
102    /// etc. A batched function takes more (or longer) arguments, and might be able to benefit from
103    /// cache locality, better re-usal of primal values, and other optimizations.
104    /// We will (before LLVM's vectorizer runs) just generate most LLVM-IR instructions `width`
105    /// times, so this massively increases code size. As such, values like 1024 are unlikely to
106    /// work. We should consider limiting this to u8 or u16, but will leave it at u32 for
107    /// experiments for now and focus on documenting the implications of a large width.
108    pub width: u32,
109    pub ret_activity: DiffActivity,
110    pub input_activity: Vec<DiffActivity>,
111}
112
113impl AutoDiffAttrs {
114    pub fn has_primal_ret(&self) -> bool {
115        #[allow(non_exhaustive_omitted_patterns)] match self.ret_activity {
    DiffActivity::Active | DiffActivity::Dual => true,
    _ => false,
}matches!(self.ret_activity, DiffActivity::Active | DiffActivity::Dual)
116    }
117}
118
119impl DiffMode {
120    pub fn is_rev(&self) -> bool {
121        #[allow(non_exhaustive_omitted_patterns)] match self {
    DiffMode::Reverse => true,
    _ => false,
}matches!(self, DiffMode::Reverse)
122    }
123    pub fn is_fwd(&self) -> bool {
124        #[allow(non_exhaustive_omitted_patterns)] match self {
    DiffMode::Forward => true,
    _ => false,
}matches!(self, DiffMode::Forward)
125    }
126}
127
128impl Display for DiffMode {
129    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
130        match self {
131            DiffMode::Error => f.write_fmt(format_args!("Error"))write!(f, "Error"),
132            DiffMode::Source => f.write_fmt(format_args!("Source"))write!(f, "Source"),
133            DiffMode::Forward => f.write_fmt(format_args!("Forward"))write!(f, "Forward"),
134            DiffMode::Reverse => f.write_fmt(format_args!("Reverse"))write!(f, "Reverse"),
135        }
136    }
137}
138
139/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...).
140/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...).
141/// Const is valid for all cases and means that we don't compute derivatives wrt. this output.
142/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg,
143/// but this is too complex to verify here. Also it's just a logic error if users get this wrong.
144pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
145    if activity == DiffActivity::None {
146        // Only valid if primal returns (), but we can't check that here.
147        return true;
148    }
149    match mode {
150        DiffMode::Error => false,
151        DiffMode::Source => false,
152        DiffMode::Forward => activity.is_dual_or_const(),
153        DiffMode::Reverse => {
154            activity == DiffActivity::Const
155                || activity == DiffActivity::Active
156                || activity == DiffActivity::ActiveOnly
157        }
158    }
159}
160
161/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value
162/// for the given argument, but we generally can't know the size of such a type.
163/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated,
164/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value
165/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent
166/// users here from marking scalars as Duplicated, due to type aliases.
167pub fn valid_ty_for_activity(ty: &Box<Ty>, activity: DiffActivity) -> bool {
168    use DiffActivity::*;
169    // It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
170    // Dual variants also support all types.
171    if activity.is_dual_or_const() {
172        return true;
173    }
174    // FIXME(ZuseZ4) We should make this more robust to also
175    // handle type aliases. Once that is done, we can be more restrictive here.
176    if #[allow(non_exhaustive_omitted_patterns)] match activity {
    Active | ActiveOnly => true,
    _ => false,
}matches!(activity, Active | ActiveOnly) {
177        return true;
178    }
179    #[allow(non_exhaustive_omitted_patterns)] match ty.kind {
    TyKind::Ptr(_) | TyKind::Ref(..) => true,
    _ => false,
}matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..))
180        && #[allow(non_exhaustive_omitted_patterns)] match activity {
    Duplicated | DuplicatedOnly => true,
    _ => false,
}matches!(activity, Duplicated | DuplicatedOnly)
181}
182pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
183    use DiffActivity::*;
184    return match mode {
185        DiffMode::Error => false,
186        DiffMode::Source => false,
187        DiffMode::Forward => activity.is_dual_or_const(),
188        DiffMode::Reverse => {
189            #[allow(non_exhaustive_omitted_patterns)] match activity {
    Active | ActiveOnly | Duplicated | DuplicatedOnly | Const => true,
    _ => false,
}matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)
190        }
191    };
192}
193
194impl Display for DiffActivity {
195    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
196        match self {
197            DiffActivity::None => f.write_fmt(format_args!("None"))write!(f, "None"),
198            DiffActivity::Const => f.write_fmt(format_args!("Const"))write!(f, "Const"),
199            DiffActivity::Active => f.write_fmt(format_args!("Active"))write!(f, "Active"),
200            DiffActivity::ActiveOnly => f.write_fmt(format_args!("ActiveOnly"))write!(f, "ActiveOnly"),
201            DiffActivity::Dual => f.write_fmt(format_args!("Dual"))write!(f, "Dual"),
202            DiffActivity::Dualv => f.write_fmt(format_args!("Dualv"))write!(f, "Dualv"),
203            DiffActivity::DualOnly => f.write_fmt(format_args!("DualOnly"))write!(f, "DualOnly"),
204            DiffActivity::DualvOnly => f.write_fmt(format_args!("DualvOnly"))write!(f, "DualvOnly"),
205            DiffActivity::Duplicated => f.write_fmt(format_args!("Duplicated"))write!(f, "Duplicated"),
206            DiffActivity::DuplicatedOnly => f.write_fmt(format_args!("DuplicatedOnly"))write!(f, "DuplicatedOnly"),
207            DiffActivity::FakeActivitySize(s) => f.write_fmt(format_args!("FakeActivitySize({0:?})", s))write!(f, "FakeActivitySize({:?})", s),
208        }
209    }
210}
211
212impl FromStr for DiffMode {
213    type Err = ();
214
215    fn from_str(s: &str) -> Result<DiffMode, ()> {
216        match s {
217            "Error" => Ok(DiffMode::Error),
218            "Source" => Ok(DiffMode::Source),
219            "Forward" => Ok(DiffMode::Forward),
220            "Reverse" => Ok(DiffMode::Reverse),
221            _ => Err(()),
222        }
223    }
224}
225impl FromStr for DiffActivity {
226    type Err = ();
227
228    fn from_str(s: &str) -> Result<DiffActivity, ()> {
229        match s {
230            "None" => Ok(DiffActivity::None),
231            "Active" => Ok(DiffActivity::Active),
232            "ActiveOnly" => Ok(DiffActivity::ActiveOnly),
233            "Const" => Ok(DiffActivity::Const),
234            "Dual" => Ok(DiffActivity::Dual),
235            "Dualv" => Ok(DiffActivity::Dualv),
236            "DualOnly" => Ok(DiffActivity::DualOnly),
237            "DualvOnly" => Ok(DiffActivity::DualvOnly),
238            "Duplicated" => Ok(DiffActivity::Duplicated),
239            "DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
240            _ => Err(()),
241        }
242    }
243}
244
245impl AutoDiffAttrs {
246    pub fn has_ret_activity(&self) -> bool {
247        self.ret_activity != DiffActivity::None
248    }
249    pub fn has_active_only_ret(&self) -> bool {
250        self.ret_activity == DiffActivity::ActiveOnly
251    }
252
253    pub const fn error() -> Self {
254        AutoDiffAttrs {
255            mode: DiffMode::Error,
256            width: 0,
257            ret_activity: DiffActivity::None,
258            input_activity: Vec::new(),
259        }
260    }
261    pub fn source() -> Self {
262        AutoDiffAttrs {
263            mode: DiffMode::Source,
264            width: 0,
265            ret_activity: DiffActivity::None,
266            input_activity: Vec::new(),
267        }
268    }
269
270    pub fn is_active(&self) -> bool {
271        self.mode != DiffMode::Error
272    }
273
274    pub fn is_source(&self) -> bool {
275        self.mode == DiffMode::Source
276    }
277    pub fn apply_autodiff(&self) -> bool {
278        !#[allow(non_exhaustive_omitted_patterns)] match self.mode {
    DiffMode::Error | DiffMode::Source => true,
    _ => false,
}matches!(self.mode, DiffMode::Error | DiffMode::Source)
279    }
280
281    pub fn into_item(
282        self,
283        source: String,
284        target: String,
285        inputs: Vec<TypeTree>,
286        output: TypeTree,
287    ) -> AutoDiffItem {
288        AutoDiffItem { source, target, inputs, output, attrs: self }
289    }
290}
291
292impl fmt::Display for AutoDiffItem {
293    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
294        f.write_fmt(format_args!("Differentiating {0} -> {1}", self.source,
        self.target))write!(f, "Differentiating {} -> {}", self.source, self.target)?;
295        f.write_fmt(format_args!(" with attributes: {0:?}", self.attrs))write!(f, " with attributes: {:?}", self.attrs)?;
296        f.write_fmt(format_args!(" with inputs: {0:?}", self.inputs))write!(f, " with inputs: {:?}", self.inputs)?;
297        f.write_fmt(format_args!(" with output: {0:?}", self.output))write!(f, " with output: {:?}", self.output)
298    }
299}