//! This module handles the user-facing autodiff macro. For each `#[autodiff(...)]` attribute,
//! we create a `RustcAutodiff` which contains the source and target function names. The source
//! is the function to which the autodiff attribute is applied, and the target is the function
//! we generate (with a name given by the user as the first autodiff argument).
56use std::fmt::{self, Display, Formatter};
7use std::str::FromStr;
89use rustc_span::{Symbol, sym};
1011use crate::expand::{Decodable, Encodable, HashStable_Generic};
12use crate::{Ty, TyKind};
1314/// Forward and Reverse Mode are well known names for automatic differentiation implementations.
15/// Enzyme does support both, but with different semantics, see DiffActivity. The First variants
16/// are a hack to support higher order derivatives. We need to compute first order derivatives
17/// before we compute second order derivatives, otherwise we would differentiate our placeholder
18/// functions. The proper solution is to recognize and resolve this DAG of autodiff invocations,
19/// as it's already done in the C++ and Julia frontend of Enzyme.
20///
21/// Documentation for using [reverse](https://enzyme.mit.edu/rust/rev.html) and
22/// [forward](https://enzyme.mit.edu/rust/fwd.html) mode is available online.
23#[derive(#[automatically_derived]
impl ::core::clone::Clone for DiffMode {
#[inline]
fn clone(&self) -> DiffMode { *self }
}Clone, #[automatically_derived]
impl ::core::marker::Copy for DiffMode { }Copy, #[automatically_derived]
impl ::core::cmp::Eq for DiffMode {
#[doc(hidden)]
#[coverage(off)]
fn assert_fields_are_eq(&self) {}
}Eq, #[automatically_derived]
impl ::core::cmp::PartialEq for DiffMode {
#[inline]
fn eq(&self, other: &DiffMode) -> bool {
let __self_discr = ::core::intrinsics::discriminant_value(self);
let __arg1_discr = ::core::intrinsics::discriminant_value(other);
__self_discr == __arg1_discr
}
}PartialEq, const _: () =
{
impl<__E: ::rustc_span::SpanEncoder> ::rustc_serialize::Encodable<__E>
for DiffMode {
fn encode(&self, __encoder: &mut __E) {
let disc =
match *self {
DiffMode::Error => { 0usize }
DiffMode::Source => { 1usize }
DiffMode::Forward => { 2usize }
DiffMode::Reverse => { 3usize }
};
::rustc_serialize::Encoder::emit_u8(__encoder, disc as u8);
match *self {
DiffMode::Error => {}
DiffMode::Source => {}
DiffMode::Forward => {}
DiffMode::Reverse => {}
}
}
}
};Encodable, const _: () =
{
impl<__D: ::rustc_span::SpanDecoder> ::rustc_serialize::Decodable<__D>
for DiffMode {
fn decode(__decoder: &mut __D) -> Self {
match ::rustc_serialize::Decoder::read_u8(__decoder) as usize
{
0usize => { DiffMode::Error }
1usize => { DiffMode::Source }
2usize => { DiffMode::Forward }
3usize => { DiffMode::Reverse }
n => {
::core::panicking::panic_fmt(format_args!("invalid enum variant tag while decoding `DiffMode`, expected 0..4, actual {0}",
n));
}
}
}
}
};Decodable, #[automatically_derived]
impl ::core::fmt::Debug for DiffMode {
#[inline]
fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
::core::fmt::Formatter::write_str(f,
match self {
DiffMode::Error => "Error",
DiffMode::Source => "Source",
DiffMode::Forward => "Forward",
DiffMode::Reverse => "Reverse",
})
}
}Debug, const _: () =
{
impl<__CTX> ::rustc_data_structures::stable_hasher::HashStable<__CTX>
for DiffMode where __CTX: crate::HashStableContext {
#[inline]
fn hash_stable(&self, __hcx: &mut __CTX,
__hasher:
&mut ::rustc_data_structures::stable_hasher::StableHasher) {
::std::mem::discriminant(self).hash_stable(__hcx, __hasher);
match *self {
DiffMode::Error => {}
DiffMode::Source => {}
DiffMode::Forward => {}
DiffMode::Reverse => {}
}
}
}
};HashStable_Generic)]
24pub enum DiffMode {
25/// No autodiff is applied (used during error handling).
26Error,
27/// The primal function which we will differentiate.
28Source,
29/// The target function, to be created using forward mode AD.
30Forward,
31/// The target function, to be created using reverse mode AD.
32Reverse,
33}
3435impl DiffMode {
36pub fn all_modes() -> &'static [Symbol] {
37&[sym::Source, sym::Forward, sym::Reverse]
38 }
39}
4041/// Dual and Duplicated (and their Only variants) are getting lowered to the same Enzyme Activity.
42/// However, under forward mode we overwrite the previous shadow value, while for reverse mode
43/// we add to the previous shadow value. To not surprise users, we picked different names.
44/// Dual numbers is also a quite well known name for forward mode AD types.
45#[derive(#[automatically_derived]
impl ::core::clone::Clone for DiffActivity {
#[inline]
fn clone(&self) -> DiffActivity {
let _: ::core::clone::AssertParamIsClone<Option<u32>>;
*self
}
}Clone, #[automatically_derived]
impl ::core::marker::Copy for DiffActivity { }Copy, #[automatically_derived]
impl ::core::cmp::Eq for DiffActivity {
#[doc(hidden)]
#[coverage(off)]
fn assert_fields_are_eq(&self) {
let _: ::core::cmp::AssertParamIsEq<Option<u32>>;
}
}Eq, #[automatically_derived]
impl ::core::cmp::PartialEq for DiffActivity {
#[inline]
fn eq(&self, other: &DiffActivity) -> bool {
let __self_discr = ::core::intrinsics::discriminant_value(self);
let __arg1_discr = ::core::intrinsics::discriminant_value(other);
__self_discr == __arg1_discr &&
match (self, other) {
(DiffActivity::FakeActivitySize(__self_0),
DiffActivity::FakeActivitySize(__arg1_0)) =>
__self_0 == __arg1_0,
_ => true,
}
}
}PartialEq, const _: () =
{
impl<__E: ::rustc_span::SpanEncoder> ::rustc_serialize::Encodable<__E>
for DiffActivity {
fn encode(&self, __encoder: &mut __E) {
let disc =
match *self {
DiffActivity::None => { 0usize }
DiffActivity::Const => { 1usize }
DiffActivity::Active => { 2usize }
DiffActivity::ActiveOnly => { 3usize }
DiffActivity::Dual => { 4usize }
DiffActivity::Dualv => { 5usize }
DiffActivity::DualOnly => { 6usize }
DiffActivity::DualvOnly => { 7usize }
DiffActivity::Duplicated => { 8usize }
DiffActivity::DuplicatedOnly => { 9usize }
DiffActivity::FakeActivitySize(ref __binding_0) => {
10usize
}
};
::rustc_serialize::Encoder::emit_u8(__encoder, disc as u8);
match *self {
DiffActivity::None => {}
DiffActivity::Const => {}
DiffActivity::Active => {}
DiffActivity::ActiveOnly => {}
DiffActivity::Dual => {}
DiffActivity::Dualv => {}
DiffActivity::DualOnly => {}
DiffActivity::DualvOnly => {}
DiffActivity::Duplicated => {}
DiffActivity::DuplicatedOnly => {}
DiffActivity::FakeActivitySize(ref __binding_0) => {
::rustc_serialize::Encodable::<__E>::encode(__binding_0,
__encoder);
}
}
}
}
};Encodable, const _: () =
{
impl<__D: ::rustc_span::SpanDecoder> ::rustc_serialize::Decodable<__D>
for DiffActivity {
fn decode(__decoder: &mut __D) -> Self {
match ::rustc_serialize::Decoder::read_u8(__decoder) as usize
{
0usize => { DiffActivity::None }
1usize => { DiffActivity::Const }
2usize => { DiffActivity::Active }
3usize => { DiffActivity::ActiveOnly }
4usize => { DiffActivity::Dual }
5usize => { DiffActivity::Dualv }
6usize => { DiffActivity::DualOnly }
7usize => { DiffActivity::DualvOnly }
8usize => { DiffActivity::Duplicated }
9usize => { DiffActivity::DuplicatedOnly }
10usize => {
DiffActivity::FakeActivitySize(::rustc_serialize::Decodable::decode(__decoder))
}
n => {
::core::panicking::panic_fmt(format_args!("invalid enum variant tag while decoding `DiffActivity`, expected 0..11, actual {0}",
n));
}
}
}
}
};Decodable, #[automatically_derived]
impl ::core::fmt::Debug for DiffActivity {
#[inline]
fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
match self {
DiffActivity::None =>
::core::fmt::Formatter::write_str(f, "None"),
DiffActivity::Const =>
::core::fmt::Formatter::write_str(f, "Const"),
DiffActivity::Active =>
::core::fmt::Formatter::write_str(f, "Active"),
DiffActivity::ActiveOnly =>
::core::fmt::Formatter::write_str(f, "ActiveOnly"),
DiffActivity::Dual =>
::core::fmt::Formatter::write_str(f, "Dual"),
DiffActivity::Dualv =>
::core::fmt::Formatter::write_str(f, "Dualv"),
DiffActivity::DualOnly =>
::core::fmt::Formatter::write_str(f, "DualOnly"),
DiffActivity::DualvOnly =>
::core::fmt::Formatter::write_str(f, "DualvOnly"),
DiffActivity::Duplicated =>
::core::fmt::Formatter::write_str(f, "Duplicated"),
DiffActivity::DuplicatedOnly =>
::core::fmt::Formatter::write_str(f, "DuplicatedOnly"),
DiffActivity::FakeActivitySize(__self_0) =>
::core::fmt::Formatter::debug_tuple_field1_finish(f,
"FakeActivitySize", &__self_0),
}
}
}Debug, const _: () =
{
impl<__CTX> ::rustc_data_structures::stable_hasher::HashStable<__CTX>
for DiffActivity where __CTX: crate::HashStableContext {
#[inline]
fn hash_stable(&self, __hcx: &mut __CTX,
__hasher:
&mut ::rustc_data_structures::stable_hasher::StableHasher) {
::std::mem::discriminant(self).hash_stable(__hcx, __hasher);
match *self {
DiffActivity::None => {}
DiffActivity::Const => {}
DiffActivity::Active => {}
DiffActivity::ActiveOnly => {}
DiffActivity::Dual => {}
DiffActivity::Dualv => {}
DiffActivity::DualOnly => {}
DiffActivity::DualvOnly => {}
DiffActivity::Duplicated => {}
DiffActivity::DuplicatedOnly => {}
DiffActivity::FakeActivitySize(ref __binding_0) => {
{ __binding_0.hash_stable(__hcx, __hasher); }
}
}
}
}
};HashStable_Generic)]
46pub enum DiffActivity {
47/// Implicit or Explicit () return type, so a special case of Const.
48None,
49/// Don't compute derivatives with respect to this input/output.
50Const,
51/// Reverse Mode, Compute derivatives for this scalar input/output.
52Active,
53/// Reverse Mode, Compute derivatives for this scalar output, but don't compute
54 /// the original return value.
55ActiveOnly,
56/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
57 /// with it.
58Dual,
59/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
60 /// with it. It expects the shadow argument to be `width` times larger than the original
61 /// input/output.
62Dualv,
63/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
64 /// with it. Drop the code which updates the original input/output for maximum performance.
65DualOnly,
66/// Forward Mode, Compute derivatives for this input/output and *overwrite* the shadow argument
67 /// with it. Drop the code which updates the original input/output for maximum performance.
68 /// It expects the shadow argument to be `width` times larger than the original input/output.
69DualvOnly,
70/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
71Duplicated,
72/// Reverse Mode, Compute derivatives for this &T or *T input and *add* it to the shadow argument.
73 /// Drop the code which updates the original input for maximum performance.
74DuplicatedOnly,
75/// All Integers must be Const, but these are used to mark the integer which represents the
76 /// length of a slice/vec. This is used for safety checks on slices.
77 /// The integer (if given) specifies the size of the slice element in bytes.
78FakeActivitySize(Option<u32>),
79}
8081impl DiffActivity {
82pub fn is_dual_or_const(&self) -> bool {
83use DiffActivity::*;
84#[allow(non_exhaustive_omitted_patterns)] match self {
Dual | DualOnly | Dualv | DualvOnly | Const => true,
_ => false,
}matches!(self, |Dual| DualOnly | Dualv | DualvOnly | Const)85 }
8687pub fn all_activities() -> &'static [Symbol] {
88&[
89 sym::None,
90 sym::Active,
91 sym::ActiveOnly,
92 sym::Const,
93 sym::Dual,
94 sym::Dualv,
95 sym::DualOnly,
96 sym::DualvOnly,
97 sym::Duplicated,
98 sym::DuplicatedOnly,
99 ]
100 }
101}
102103impl DiffMode {
104pub fn is_rev(&self) -> bool {
105#[allow(non_exhaustive_omitted_patterns)] match self {
DiffMode::Reverse => true,
_ => false,
}matches!(self, DiffMode::Reverse)106 }
107pub fn is_fwd(&self) -> bool {
108#[allow(non_exhaustive_omitted_patterns)] match self {
DiffMode::Forward => true,
_ => false,
}matches!(self, DiffMode::Forward)109 }
110}
111112impl Displayfor DiffMode {
113fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
114match self {
115 DiffMode::Error => f.write_fmt(format_args!("Error"))write!(f, "Error"),
116 DiffMode::Source => f.write_fmt(format_args!("Source"))write!(f, "Source"),
117 DiffMode::Forward => f.write_fmt(format_args!("Forward"))write!(f, "Forward"),
118 DiffMode::Reverse => f.write_fmt(format_args!("Reverse"))write!(f, "Reverse"),
119 }
120 }
121}
122123/// Active(Only) is valid in reverse-mode AD for scalar float returns (f16/f32/...).
124/// Dual(Only) is valid in forward-mode AD for scalar float returns (f16/f32/...).
125/// Const is valid for all cases and means that we don't compute derivatives wrt. this output.
126/// That usually means we have a &mut or *mut T output and compute derivatives wrt. that arg,
127/// but this is too complex to verify here. Also it's just a logic error if users get this wrong.
128pub fn valid_ret_activity(mode: DiffMode, activity: DiffActivity) -> bool {
129if activity == DiffActivity::None {
130// Only valid if primal returns (), but we can't check that here.
131return true;
132 }
133match mode {
134 DiffMode::Error => false,
135 DiffMode::Source => false,
136 DiffMode::Forward => activity.is_dual_or_const(),
137 DiffMode::Reverse => {
138activity == DiffActivity::Const139 || activity == DiffActivity::Active140 || activity == DiffActivity::ActiveOnly141 }
142 }
143}
144145/// For indirections (ptr/ref) we can't use Active, since Active allocates a shadow value
146/// for the given argument, but we generally can't know the size of such a type.
147/// For scalar types (f16/f32/f64/f128) we can use Active and we can't use Duplicated,
148/// since Duplicated expects a mutable ref/ptr and we would thus end up with a shadow value
149/// who is an indirect type, which doesn't match the primal scalar type. We can't prevent
150/// users here from marking scalars as Duplicated, due to type aliases.
151pub fn valid_ty_for_activity(ty: &Box<Ty>, activity: DiffActivity) -> bool {
152use DiffActivity::*;
153// It's always allowed to mark something as Const, since we won't compute derivatives wrt. it.
154 // Dual variants also support all types.
155if activity.is_dual_or_const() {
156return true;
157 }
158// FIXME(ZuseZ4) We should make this more robust to also
159 // handle type aliases. Once that is done, we can be more restrictive here.
160if #[allow(non_exhaustive_omitted_patterns)] match activity {
Active | ActiveOnly => true,
_ => false,
}matches!(activity, Active | ActiveOnly) {
161return true;
162 }
163#[allow(non_exhaustive_omitted_patterns)] match ty.kind {
TyKind::Ptr(_) | TyKind::Ref(..) => true,
_ => false,
}matches!(ty.kind, TyKind::Ptr(_) | TyKind::Ref(..))164 && #[allow(non_exhaustive_omitted_patterns)] match activity {
Duplicated | DuplicatedOnly => true,
_ => false,
}matches!(activity, Duplicated | DuplicatedOnly)165}
166pub fn valid_input_activity(mode: DiffMode, activity: DiffActivity) -> bool {
167use DiffActivity::*;
168return match mode {
169 DiffMode::Error => false,
170 DiffMode::Source => false,
171 DiffMode::Forward => activity.is_dual_or_const(),
172 DiffMode::Reverse => {
173#[allow(non_exhaustive_omitted_patterns)] match activity {
Active | ActiveOnly | Duplicated | DuplicatedOnly | Const => true,
_ => false,
}matches!(activity, Active | ActiveOnly | Duplicated | DuplicatedOnly | Const)174 }
175 };
176}
177178impl Displayfor DiffActivity {
179fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
180match self {
181 DiffActivity::None => f.write_fmt(format_args!("None"))write!(f, "None"),
182 DiffActivity::Const => f.write_fmt(format_args!("Const"))write!(f, "Const"),
183 DiffActivity::Active => f.write_fmt(format_args!("Active"))write!(f, "Active"),
184 DiffActivity::ActiveOnly => f.write_fmt(format_args!("ActiveOnly"))write!(f, "ActiveOnly"),
185 DiffActivity::Dual => f.write_fmt(format_args!("Dual"))write!(f, "Dual"),
186 DiffActivity::Dualv => f.write_fmt(format_args!("Dualv"))write!(f, "Dualv"),
187 DiffActivity::DualOnly => f.write_fmt(format_args!("DualOnly"))write!(f, "DualOnly"),
188 DiffActivity::DualvOnly => f.write_fmt(format_args!("DualvOnly"))write!(f, "DualvOnly"),
189 DiffActivity::Duplicated => f.write_fmt(format_args!("Duplicated"))write!(f, "Duplicated"),
190 DiffActivity::DuplicatedOnly => f.write_fmt(format_args!("DuplicatedOnly"))write!(f, "DuplicatedOnly"),
191 DiffActivity::FakeActivitySize(s) => f.write_fmt(format_args!("FakeActivitySize({0:?})", s))write!(f, "FakeActivitySize({:?})", s),
192 }
193 }
194}
195196impl FromStrfor DiffMode {
197type Err = ();
198199fn from_str(s: &str) -> Result<DiffMode, ()> {
200match s {
201"Error" => Ok(DiffMode::Error),
202"Source" => Ok(DiffMode::Source),
203"Forward" => Ok(DiffMode::Forward),
204"Reverse" => Ok(DiffMode::Reverse),
205_ => Err(()),
206 }
207 }
208}
209impl FromStrfor DiffActivity {
210type Err = ();
211212fn from_str(s: &str) -> Result<DiffActivity, ()> {
213match s {
214"None" => Ok(DiffActivity::None),
215"Active" => Ok(DiffActivity::Active),
216"ActiveOnly" => Ok(DiffActivity::ActiveOnly),
217"Const" => Ok(DiffActivity::Const),
218"Dual" => Ok(DiffActivity::Dual),
219"Dualv" => Ok(DiffActivity::Dualv),
220"DualOnly" => Ok(DiffActivity::DualOnly),
221"DualvOnly" => Ok(DiffActivity::DualvOnly),
222"Duplicated" => Ok(DiffActivity::Duplicated),
223"DuplicatedOnly" => Ok(DiffActivity::DuplicatedOnly),
224_ => Err(()),
225 }
226 }
227}