rustc_codegen_llvm/
typetree.rs

1use std::ffi::{CString, c_char, c_uint};
2
3use rustc_ast::expand::typetree::{FncTree, TypeTree as RustTypeTree};
4
5use crate::attributes;
6use crate::llvm::{self, EnzymeWrapper, Value};
7
8fn to_enzyme_typetree(
9    rust_typetree: RustTypeTree,
10    _data_layout: &str,
11    llcx: &llvm::Context,
12) -> llvm::TypeTree {
13    let mut enzyme_tt = llvm::TypeTree::new();
14    process_typetree_recursive(&mut enzyme_tt, &rust_typetree, &[], llcx);
15    enzyme_tt
16}
17fn process_typetree_recursive(
18    enzyme_tt: &mut llvm::TypeTree,
19    rust_typetree: &RustTypeTree,
20    parent_indices: &[i64],
21    llcx: &llvm::Context,
22) {
23    for rust_type in &rust_typetree.0 {
24        let concrete_type = match rust_type.kind {
25            rustc_ast::expand::typetree::Kind::Anything => llvm::CConcreteType::DT_Anything,
26            rustc_ast::expand::typetree::Kind::Integer => llvm::CConcreteType::DT_Integer,
27            rustc_ast::expand::typetree::Kind::Pointer => llvm::CConcreteType::DT_Pointer,
28            rustc_ast::expand::typetree::Kind::Half => llvm::CConcreteType::DT_Half,
29            rustc_ast::expand::typetree::Kind::Float => llvm::CConcreteType::DT_Float,
30            rustc_ast::expand::typetree::Kind::Double => llvm::CConcreteType::DT_Double,
31            rustc_ast::expand::typetree::Kind::F128 => llvm::CConcreteType::DT_FP128,
32            rustc_ast::expand::typetree::Kind::Unknown => llvm::CConcreteType::DT_Unknown,
33        };
34
35        let mut indices = parent_indices.to_vec();
36        if !parent_indices.is_empty() {
37            indices.push(rust_type.offset as i64);
38        } else if rust_type.offset == -1 {
39            indices.push(-1);
40        } else {
41            indices.push(rust_type.offset as i64);
42        }
43
44        enzyme_tt.insert(&indices, concrete_type, llcx);
45
46        if rust_type.kind == rustc_ast::expand::typetree::Kind::Pointer
47            && !rust_type.child.0.is_empty()
48        {
49            process_typetree_recursive(enzyme_tt, &rust_type.child, &indices, llcx);
50        }
51    }
52}
53
54#[cfg_attr(not(feature = "llvm_enzyme"), allow(unused))]
55pub(crate) fn add_tt<'ll>(
56    llmod: &'ll llvm::Module,
57    llcx: &'ll llvm::Context,
58    fn_def: &'ll Value,
59    tt: FncTree,
60) {
61    // TypeTree processing uses functions from Enzyme, which we might not have available if we did
62    // not build this compiler with `llvm_enzyme`. This feature is not strictly necessary, but
63    // skipping this function increases the chance that Enzyme fails to compile some code.
64    // FIXME(autodiff): In the future we should conditionally run this function even without the
65    // `llvm_enzyme` feature, in case that libEnzyme was provided via rustup.
66    #[cfg(not(feature = "llvm_enzyme"))]
67    return;
68
69    let inputs = tt.args;
70    let ret_tt: RustTypeTree = tt.ret;
71
72    let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(&*llmod) };
73    let llvm_data_layout =
74        std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes())
75            .expect("got a non-UTF8 data-layout from LLVM");
76
77    let attr_name = "enzyme_type";
78    let c_attr_name = CString::new(attr_name).unwrap();
79
80    for (i, input) in inputs.iter().enumerate() {
81        unsafe {
82            let enzyme_tt = to_enzyme_typetree(input.clone(), llvm_data_layout, llcx);
83            let enzyme_wrapper = EnzymeWrapper::get_instance();
84            let c_str = enzyme_wrapper.tree_to_string(enzyme_tt.inner);
85            let c_str = std::ffi::CStr::from_ptr(c_str);
86
87            let attr = llvm::LLVMCreateStringAttribute(
88                llcx,
89                c_attr_name.as_ptr(),
90                c_attr_name.as_bytes().len() as c_uint,
91                c_str.as_ptr(),
92                c_str.to_bytes().len() as c_uint,
93            );
94
95            attributes::apply_to_llfn(fn_def, llvm::AttributePlace::Argument(i as u32), &[attr]);
96            enzyme_wrapper.tree_to_string_free(c_str.as_ptr());
97        }
98    }
99
100    unsafe {
101        let enzyme_tt = to_enzyme_typetree(ret_tt, llvm_data_layout, llcx);
102        let enzyme_wrapper = EnzymeWrapper::get_instance();
103        let c_str = enzyme_wrapper.tree_to_string(enzyme_tt.inner);
104        let c_str = std::ffi::CStr::from_ptr(c_str);
105
106        let ret_attr = llvm::LLVMCreateStringAttribute(
107            llcx,
108            c_attr_name.as_ptr(),
109            c_attr_name.as_bytes().len() as c_uint,
110            c_str.as_ptr(),
111            c_str.to_bytes().len() as c_uint,
112        );
113
114        attributes::apply_to_llfn(fn_def, llvm::AttributePlace::ReturnValue, &[ret_attr]);
115        enzyme_wrapper.tree_to_string_free(c_str.as_ptr());
116    }
117}