rustc_ast/expand/
typetree.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
//! This module contains the definition of the `TypeTree` and `Type` structs.
//! They are thin Rust wrappers around the TypeTrees used by Enzyme as the LLVM based autodiff
//! backend. The Enzyme TypeTrees currently have various limitations and should be rewritten, so the
//! Rust frontend obviously has the same limitations. The main motivation of TypeTrees is to
//! represent how a type looks like "in memory". Enzyme can deduce this based on usage patterns in
//! the user code, but this is extremely slow and not even always sufficient. As such we lower some
//! information from rustc to help Enzyme. For a full explanation of their design it is necessary to
//! analyze the implementation in Enzyme core itself. As a rough summary, `-1` in Enzyme speech means
//! everywhere. That is `{0:-1: Float}` means at index 0 you have a ptr, if you dereference it it
//! will be floats everywhere. Thus `* f32`. If you have `{-1:int}` it means int's everywhere,
//! e.g. [i32; N]. `{0:-1:-1 float}` then means one pointer at offset 0, if you dereference it there
//! will be only pointers, if you dereference these new pointers they will point to array of floats.
//! Generally, it allows byte-specific descriptions.
//! FIXME: This description might be partly inaccurate and should be extended, along with
//! adding documentation to the corresponding Enzyme core code.
//! FIXME: Rewrite the TypeTree logic in Enzyme core to reduce the need for the rustc frontend to
//! provide typetree information.
//! FIXME: We should also re-evaluate where we create TypeTrees from Rust types, since MIR
//! representations of some types might not be accurate. For example a vector of floats might be
//! represented as a vector of u8s in MIR in some cases.

use std::fmt;

use crate::expand::{Decodable, Encodable, HashStable_Generic};

#[derive(Clone, Copy, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub enum Kind {
    Anything,
    Integer,
    Pointer,
    Half,
    Float,
    Double,
    Unknown,
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct TypeTree(pub Vec<Type>);

impl TypeTree {
    pub fn new() -> Self {
        Self(Vec::new())
    }
    pub fn all_ints() -> Self {
        Self(vec![Type { offset: -1, size: 1, kind: Kind::Integer, child: TypeTree::new() }])
    }
    pub fn int(size: usize) -> Self {
        let mut ints = Vec::with_capacity(size);
        for i in 0..size {
            ints.push(Type {
                offset: i as isize,
                size: 1,
                kind: Kind::Integer,
                child: TypeTree::new(),
            });
        }
        Self(ints)
    }
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct FncTree {
    pub args: Vec<TypeTree>,
    pub ret: TypeTree,
}

#[derive(Clone, Eq, PartialEq, Encodable, Decodable, Debug, HashStable_Generic)]
pub struct Type {
    pub offset: isize,
    pub size: usize,
    pub kind: Kind,
    pub child: TypeTree,
}

impl Type {
    pub fn add_offset(self, add: isize) -> Self {
        let offset = match self.offset {
            -1 => add,
            x => add + x,
        };

        Self { size: self.size, kind: self.kind, child: self.child, offset }
    }
}

impl fmt::Display for Type {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        <Self as fmt::Debug>::fmt(self, f)
    }
}