rustc_ast/expand/
typetree.rs

//! This module contains the definitions of the `TypeTree` and `Type` structs.
//! They are thin Rust wrappers around the TypeTrees used by Enzyme, the LLVM-based autodiff
//! backend. The Enzyme TypeTrees currently have various limitations and should be rewritten, so the
//! Rust frontend necessarily shares those limitations. The main motivation of TypeTrees is to
//! represent what a type looks like "in memory". Enzyme can deduce this from usage patterns in
//! the user code, but this is extremely slow and not always sufficient. We therefore lower some
//! information from rustc to help Enzyme. A full explanation of their design requires analyzing
//! the implementation in Enzyme core itself. As a rough summary, `-1` in Enzyme notation means
//! "everywhere". For example, `{0:-1: Float}` means that there is a pointer at offset 0, and
//! dereferencing it yields floats everywhere, i.e. `*f32`. `{-1:int}` means integers everywhere,
//! e.g. `[i32; N]`. `{0:-1:-1: Float}` means there is one pointer at offset 0; dereferencing it
//! yields only pointers, and dereferencing those pointers yields arrays of floats.
//! In general, this notation allows byte-specific descriptions.
//! FIXME: This description might be partly inaccurate and should be extended, along with
//! adding documentation to the corresponding Enzyme core code.
//! FIXME: Rewrite the TypeTree logic in Enzyme core to reduce the need for the rustc frontend to
//! provide TypeTree information.
//! FIXME: We should also re-evaluate where we create TypeTrees from Rust types, since the MIR
//! representations of some types might not be accurate. For example, a vector of floats might be
//! represented as a vector of `u8`s in MIR in some cases.
21
22use std::fmt;
23
24use crate::expand::{Decodable, Encodable, HashStable_Generic};
25
26#[derive(#[automatically_derived]
impl ::core::clone::Clone for Kind {
    #[inline]
    fn clone(&self) -> Kind { *self }
}Clone, #[automatically_derived]
impl ::core::marker::Copy for Kind { }Copy, #[automatically_derived]
impl ::core::cmp::Eq for Kind {
    #[inline]
    #[doc(hidden)]
    #[coverage(off)]
    fn assert_receiver_is_total_eq(&self) -> () {}
}Eq, #[automatically_derived]
impl ::core::cmp::PartialEq for Kind {
    #[inline]
    fn eq(&self, other: &Kind) -> bool {
        let __self_discr = ::core::intrinsics::discriminant_value(self);
        let __arg1_discr = ::core::intrinsics::discriminant_value(other);
        __self_discr == __arg1_discr
    }
}PartialEq, const _: () =
    {
        impl<__E: ::rustc_span::SpanEncoder> ::rustc_serialize::Encodable<__E>
            for Kind {
            fn encode(&self, __encoder: &mut __E) {
                let disc =
                    match *self {
                        Kind::Anything => { 0usize }
                        Kind::Integer => { 1usize }
                        Kind::Pointer => { 2usize }
                        Kind::Half => { 3usize }
                        Kind::Float => { 4usize }
                        Kind::Double => { 5usize }
                        Kind::F128 => { 6usize }
                        Kind::Unknown => { 7usize }
                    };
                ::rustc_serialize::Encoder::emit_u8(__encoder, disc as u8);
                match *self {
                    Kind::Anything => {}
                    Kind::Integer => {}
                    Kind::Pointer => {}
                    Kind::Half => {}
                    Kind::Float => {}
                    Kind::Double => {}
                    Kind::F128 => {}
                    Kind::Unknown => {}
                }
            }
        }
    };Encodable, const _: () =
    {
        impl<__D: ::rustc_span::SpanDecoder> ::rustc_serialize::Decodable<__D>
            for Kind {
            fn decode(__decoder: &mut __D) -> Self {
                match ::rustc_serialize::Decoder::read_u8(__decoder) as usize
                    {
                    0usize => { Kind::Anything }
                    1usize => { Kind::Integer }
                    2usize => { Kind::Pointer }
                    3usize => { Kind::Half }
                    4usize => { Kind::Float }
                    5usize => { Kind::Double }
                    6usize => { Kind::F128 }
                    7usize => { Kind::Unknown }
                    n => {
                        ::core::panicking::panic_fmt(format_args!("invalid enum variant tag while decoding `Kind`, expected 0..8, actual {0}",
                                n));
                    }
                }
            }
        }
    };Decodable, #[automatically_derived]
impl ::core::fmt::Debug for Kind {
    #[inline]
    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
        ::core::fmt::Formatter::write_str(f,
            match self {
                Kind::Anything => "Anything",
                Kind::Integer => "Integer",
                Kind::Pointer => "Pointer",
                Kind::Half => "Half",
                Kind::Float => "Float",
                Kind::Double => "Double",
                Kind::F128 => "F128",
                Kind::Unknown => "Unknown",
            })
    }
}Debug, const _: () =
    {
        impl<__CTX> ::rustc_data_structures::stable_hasher::HashStable<__CTX>
            for Kind where __CTX: crate::HashStableContext {
            #[inline]
            fn hash_stable(&self, __hcx: &mut __CTX,
                __hasher:
                    &mut ::rustc_data_structures::stable_hasher::StableHasher) {
                ::std::mem::discriminant(self).hash_stable(__hcx, __hasher);
                match *self {
                    Kind::Anything => {}
                    Kind::Integer => {}
                    Kind::Pointer => {}
                    Kind::Half => {}
                    Kind::Float => {}
                    Kind::Double => {}
                    Kind::F128 => {}
                    Kind::Unknown => {}
                }
            }
        }
    };HashStable_Generic)]
27pub enum Kind {
28    Anything,
29    Integer,
30    Pointer,
31    Half,
32    Float,
33    Double,
34    F128,
35    Unknown,
36}
37
38#[derive(#[automatically_derived]
impl ::core::clone::Clone for TypeTree {
    #[inline]
    fn clone(&self) -> TypeTree {
        TypeTree(::core::clone::Clone::clone(&self.0))
    }
}Clone, #[automatically_derived]
impl ::core::cmp::Eq for TypeTree {
    #[inline]
    #[doc(hidden)]
    #[coverage(off)]
    fn assert_receiver_is_total_eq(&self) -> () {
        let _: ::core::cmp::AssertParamIsEq<Vec<Type>>;
    }
}Eq, #[automatically_derived]
impl ::core::cmp::PartialEq for TypeTree {
    #[inline]
    fn eq(&self, other: &TypeTree) -> bool { self.0 == other.0 }
}PartialEq, const _: () =
    {
        impl<__E: ::rustc_span::SpanEncoder> ::rustc_serialize::Encodable<__E>
            for TypeTree {
            fn encode(&self, __encoder: &mut __E) {
                match *self {
                    TypeTree(ref __binding_0) => {
                        ::rustc_serialize::Encodable::<__E>::encode(__binding_0,
                            __encoder);
                    }
                }
            }
        }
    };Encodable, const _: () =
    {
        impl<__D: ::rustc_span::SpanDecoder> ::rustc_serialize::Decodable<__D>
            for TypeTree {
            fn decode(__decoder: &mut __D) -> Self {
                TypeTree(::rustc_serialize::Decodable::decode(__decoder))
            }
        }
    };Decodable, #[automatically_derived]
impl ::core::fmt::Debug for TypeTree {
    #[inline]
    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
        ::core::fmt::Formatter::debug_tuple_field1_finish(f, "TypeTree",
            &&self.0)
    }
}Debug, const _: () =
    {
        impl<__CTX> ::rustc_data_structures::stable_hasher::HashStable<__CTX>
            for TypeTree where __CTX: crate::HashStableContext {
            #[inline]
            fn hash_stable(&self, __hcx: &mut __CTX,
                __hasher:
                    &mut ::rustc_data_structures::stable_hasher::StableHasher) {
                match *self {
                    TypeTree(ref __binding_0) => {
                        { __binding_0.hash_stable(__hcx, __hasher); }
                    }
                }
            }
        }
    };HashStable_Generic)]
39pub struct TypeTree(pub Vec<Type>);
40
41impl TypeTree {
42    pub fn new() -> Self {
43        Self(Vec::new())
44    }
45    pub fn all_ints() -> Self {
46        Self(<[_]>::into_vec(::alloc::boxed::box_new([Type {
                    offset: -1,
                    size: 1,
                    kind: Kind::Integer,
                    child: TypeTree::new(),
                }]))vec![Type { offset: -1, size: 1, kind: Kind::Integer, child: TypeTree::new() }])
47    }
48    pub fn int(size: usize) -> Self {
49        let mut ints = Vec::with_capacity(size);
50        for i in 0..size {
51            ints.push(Type {
52                offset: i as isize,
53                size: 1,
54                kind: Kind::Integer,
55                child: TypeTree::new(),
56            });
57        }
58        Self(ints)
59    }
60}
61
62#[derive(#[automatically_derived]
impl ::core::clone::Clone for FncTree {
    #[inline]
    fn clone(&self) -> FncTree {
        FncTree {
            args: ::core::clone::Clone::clone(&self.args),
            ret: ::core::clone::Clone::clone(&self.ret),
        }
    }
}Clone, #[automatically_derived]
impl ::core::cmp::Eq for FncTree {
    #[inline]
    #[doc(hidden)]
    #[coverage(off)]
    fn assert_receiver_is_total_eq(&self) -> () {
        let _: ::core::cmp::AssertParamIsEq<Vec<TypeTree>>;
        let _: ::core::cmp::AssertParamIsEq<TypeTree>;
    }
}Eq, #[automatically_derived]
impl ::core::cmp::PartialEq for FncTree {
    #[inline]
    fn eq(&self, other: &FncTree) -> bool {
        self.args == other.args && self.ret == other.ret
    }
}PartialEq, const _: () =
    {
        impl<__E: ::rustc_span::SpanEncoder> ::rustc_serialize::Encodable<__E>
            for FncTree {
            fn encode(&self, __encoder: &mut __E) {
                match *self {
                    FncTree { args: ref __binding_0, ret: ref __binding_1 } => {
                        ::rustc_serialize::Encodable::<__E>::encode(__binding_0,
                            __encoder);
                        ::rustc_serialize::Encodable::<__E>::encode(__binding_1,
                            __encoder);
                    }
                }
            }
        }
    };Encodable, const _: () =
    {
        impl<__D: ::rustc_span::SpanDecoder> ::rustc_serialize::Decodable<__D>
            for FncTree {
            fn decode(__decoder: &mut __D) -> Self {
                FncTree {
                    args: ::rustc_serialize::Decodable::decode(__decoder),
                    ret: ::rustc_serialize::Decodable::decode(__decoder),
                }
            }
        }
    };Decodable, #[automatically_derived]
impl ::core::fmt::Debug for FncTree {
    #[inline]
    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
        ::core::fmt::Formatter::debug_struct_field2_finish(f, "FncTree",
            "args", &self.args, "ret", &&self.ret)
    }
}Debug, const _: () =
    {
        impl<__CTX> ::rustc_data_structures::stable_hasher::HashStable<__CTX>
            for FncTree where __CTX: crate::HashStableContext {
            #[inline]
            fn hash_stable(&self, __hcx: &mut __CTX,
                __hasher:
                    &mut ::rustc_data_structures::stable_hasher::StableHasher) {
                match *self {
                    FncTree { args: ref __binding_0, ret: ref __binding_1 } => {
                        { __binding_0.hash_stable(__hcx, __hasher); }
                        { __binding_1.hash_stable(__hcx, __hasher); }
                    }
                }
            }
        }
    };HashStable_Generic)]
63pub struct FncTree {
64    pub args: Vec<TypeTree>,
65    pub ret: TypeTree,
66}
67
68#[derive(#[automatically_derived]
impl ::core::clone::Clone for Type {
    #[inline]
    fn clone(&self) -> Type {
        Type {
            offset: ::core::clone::Clone::clone(&self.offset),
            size: ::core::clone::Clone::clone(&self.size),
            kind: ::core::clone::Clone::clone(&self.kind),
            child: ::core::clone::Clone::clone(&self.child),
        }
    }
}Clone, #[automatically_derived]
impl ::core::cmp::Eq for Type {
    #[inline]
    #[doc(hidden)]
    #[coverage(off)]
    fn assert_receiver_is_total_eq(&self) -> () {
        let _: ::core::cmp::AssertParamIsEq<isize>;
        let _: ::core::cmp::AssertParamIsEq<usize>;
        let _: ::core::cmp::AssertParamIsEq<Kind>;
        let _: ::core::cmp::AssertParamIsEq<TypeTree>;
    }
}Eq, #[automatically_derived]
impl ::core::cmp::PartialEq for Type {
    #[inline]
    fn eq(&self, other: &Type) -> bool {
        self.offset == other.offset && self.size == other.size &&
                self.kind == other.kind && self.child == other.child
    }
}PartialEq, const _: () =
    {
        impl<__E: ::rustc_span::SpanEncoder> ::rustc_serialize::Encodable<__E>
            for Type {
            fn encode(&self, __encoder: &mut __E) {
                match *self {
                    Type {
                        offset: ref __binding_0,
                        size: ref __binding_1,
                        kind: ref __binding_2,
                        child: ref __binding_3 } => {
                        ::rustc_serialize::Encodable::<__E>::encode(__binding_0,
                            __encoder);
                        ::rustc_serialize::Encodable::<__E>::encode(__binding_1,
                            __encoder);
                        ::rustc_serialize::Encodable::<__E>::encode(__binding_2,
                            __encoder);
                        ::rustc_serialize::Encodable::<__E>::encode(__binding_3,
                            __encoder);
                    }
                }
            }
        }
    };Encodable, const _: () =
    {
        impl<__D: ::rustc_span::SpanDecoder> ::rustc_serialize::Decodable<__D>
            for Type {
            fn decode(__decoder: &mut __D) -> Self {
                Type {
                    offset: ::rustc_serialize::Decodable::decode(__decoder),
                    size: ::rustc_serialize::Decodable::decode(__decoder),
                    kind: ::rustc_serialize::Decodable::decode(__decoder),
                    child: ::rustc_serialize::Decodable::decode(__decoder),
                }
            }
        }
    };Decodable, #[automatically_derived]
impl ::core::fmt::Debug for Type {
    #[inline]
    fn fmt(&self, f: &mut ::core::fmt::Formatter) -> ::core::fmt::Result {
        ::core::fmt::Formatter::debug_struct_field4_finish(f, "Type",
            "offset", &self.offset, "size", &self.size, "kind", &self.kind,
            "child", &&self.child)
    }
}Debug, const _: () =
    {
        impl<__CTX> ::rustc_data_structures::stable_hasher::HashStable<__CTX>
            for Type where __CTX: crate::HashStableContext {
            #[inline]
            fn hash_stable(&self, __hcx: &mut __CTX,
                __hasher:
                    &mut ::rustc_data_structures::stable_hasher::StableHasher) {
                match *self {
                    Type {
                        offset: ref __binding_0,
                        size: ref __binding_1,
                        kind: ref __binding_2,
                        child: ref __binding_3 } => {
                        { __binding_0.hash_stable(__hcx, __hasher); }
                        { __binding_1.hash_stable(__hcx, __hasher); }
                        { __binding_2.hash_stable(__hcx, __hasher); }
                        { __binding_3.hash_stable(__hcx, __hasher); }
                    }
                }
            }
        }
    };HashStable_Generic)]
69pub struct Type {
70    pub offset: isize,
71    pub size: usize,
72    pub kind: Kind,
73    pub child: TypeTree,
74}
75
76impl Type {
77    pub fn add_offset(self, add: isize) -> Self {
78        let offset = match self.offset {
79            -1 => add,
80            x => add + x,
81        };
82
83        Self { size: self.size, kind: self.kind, child: self.child, offset }
84    }
85}
86
87impl fmt::Display for Type {
88    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89        <Self as fmt::Debug>::fmt(self, f)
90    }
91}