rustc_codegen_llvm/
typetree.rs

1use rustc_ast::expand::typetree::FncTree;
2#[cfg(feature = "llvm_enzyme")]
3use {
4    crate::attributes,
5    crate::llvm::EnzymeWrapper,
6    rustc_ast::expand::typetree::TypeTree as RustTypeTree,
7    std::ffi::{CString, c_char, c_uint},
8};
9
10use crate::llvm::{self, Value};
11
/// Builds a fresh Enzyme type tree mirroring `rust_typetree`.
///
/// The conversion starts at the root, i.e. with an empty parent index path, and
/// delegates the per-entry work to `process_typetree_recursive`.
/// `_data_layout` is currently ignored (kept for interface stability; presumably
/// reserved for layout-dependent conversion — TODO confirm).
#[cfg(feature = "llvm_enzyme")]
fn to_enzyme_typetree(
    rust_typetree: RustTypeTree,
    _data_layout: &str,
    llcx: &llvm::Context,
) -> llvm::TypeTree {
    let root_path: [i64; 0] = [];
    let mut tree = llvm::TypeTree::new();
    process_typetree_recursive(&mut tree, &rust_typetree, &root_path, llcx);
    tree
}
/// Appends every entry of `rust_typetree` to `enzyme_tt`, descending into pointees.
///
/// Each entry is inserted under the index path `parent_indices ++ [offset]`.
/// An offset of `-1` is forwarded unchanged (presumably Enzyme's "any offset"
/// wildcard — TODO confirm). Recursion happens only for `Kind::Pointer` entries
/// whose child tree is non-empty; the entry's own path becomes the parent path.
#[cfg(feature = "llvm_enzyme")]
fn process_typetree_recursive(
    enzyme_tt: &mut llvm::TypeTree,
    rust_typetree: &RustTypeTree,
    parent_indices: &[i64],
    llcx: &llvm::Context,
) {
    use rustc_ast::expand::typetree::Kind;

    for rust_type in &rust_typetree.0 {
        let concrete_type = match rust_type.kind {
            Kind::Anything => llvm::CConcreteType::DT_Anything,
            Kind::Integer => llvm::CConcreteType::DT_Integer,
            Kind::Pointer => llvm::CConcreteType::DT_Pointer,
            Kind::Half => llvm::CConcreteType::DT_Half,
            Kind::Float => llvm::CConcreteType::DT_Float,
            Kind::Double => llvm::CConcreteType::DT_Double,
            Kind::F128 => llvm::CConcreteType::DT_FP128,
            Kind::Unknown => llvm::CConcreteType::DT_Unknown,
        };

        // The offset is appended unconditionally: `-1` converts to `-1i64` like any
        // other value, so neither the root nor nested levels need a special case.
        let mut indices = parent_indices.to_vec();
        indices.push(rust_type.offset as i64);

        enzyme_tt.insert(&indices, concrete_type, llcx);

        // Only pointers have pointee layouts worth describing at a deeper path.
        if matches!(rust_type.kind, Kind::Pointer) && !rust_type.child.0.is_empty() {
            process_typetree_recursive(enzyme_tt, &rust_type.child, &indices, llcx);
        }
    }
}
59
/// Attaches Enzyme type-tree information to `fn_def` as string attributes.
///
/// Every argument tree in `tt.args` (by position) and the return tree `tt.ret` are
/// converted to Enzyme type trees and serialized through the Enzyme wrapper. Each
/// result is attached as an `"enzyme_type"` string attribute on the matching
/// argument or on the return value.
///
/// # Panics
///
/// Panics in two cases:
/// - LLVM returns a data-layout string that is not valid UTF-8.
/// - An argument index does not fit in `u32`.
#[cfg(feature = "llvm_enzyme")]
pub(crate) fn add_tt<'ll>(
    llmod: &'ll llvm::Module,
    llcx: &'ll llvm::Context,
    fn_def: &'ll Value,
    tt: FncTree,
) {
    let inputs = tt.args;
    let ret_tt: RustTypeTree = tt.ret;

    // SAFETY: `llmod` is a live module reference. LLVM returns a NUL-terminated
    // string that is owned by the module and is read immediately below.
    let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(llmod) };
    let llvm_data_layout =
        std::str::from_utf8(unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }.to_bytes())
            .expect("got a non-UTF8 data-layout from LLVM");

    let attr_name = "enzyme_type";
    let c_attr_name = CString::new(attr_name).unwrap();

    // `inputs` is owned, so each tree is moved into the conversion instead of cloned.
    for (i, input) in inputs.into_iter().enumerate() {
        let arg_idx = u32::try_from(i).expect("argument index does not fit in u32");
        attach_enzyme_type(
            llcx,
            fn_def,
            llvm::AttributePlace::Argument(arg_idx),
            input,
            llvm_data_layout,
            &c_attr_name,
        );
    }

    attach_enzyme_type(
        llcx,
        fn_def,
        llvm::AttributePlace::ReturnValue,
        ret_tt,
        llvm_data_layout,
        &c_attr_name,
    );
}

/// Converts `tree` to an Enzyme type tree, serializes it, and applies it to
/// `fn_def` at `place` as a string attribute named `c_attr_name`.
#[cfg(feature = "llvm_enzyme")]
fn attach_enzyme_type<'ll>(
    llcx: &'ll llvm::Context,
    fn_def: &'ll Value,
    place: llvm::AttributePlace,
    tree: RustTypeTree,
    data_layout: &str,
    c_attr_name: &std::ffi::CStr,
) {
    // Built before the wrapper is obtained, so it is dropped after the wrapper
    // (same order as before in case the wrapper is a guard that `TypeTree` also needs).
    let enzyme_tt = to_enzyme_typetree(tree, data_layout, llcx);

    // SAFETY: `tree_to_string` is assumed to return a valid NUL-terminated string
    // that stays alive until `tree_to_string_free`. `c_str` is not used after that
    // call. `LLVMCreateStringAttribute` copies both buffers into `llcx`.
    unsafe {
        let enzyme_wrapper = EnzymeWrapper::get_instance();
        let c_str = std::ffi::CStr::from_ptr(enzyme_wrapper.tree_to_string(enzyme_tt.inner));

        let attr = llvm::LLVMCreateStringAttribute(
            llcx,
            c_attr_name.as_ptr(),
            c_attr_name.to_bytes().len() as c_uint,
            c_str.as_ptr(),
            c_str.to_bytes().len() as c_uint,
        );

        attributes::apply_to_llfn(fn_def, place, &[attr]);
        enzyme_wrapper.tree_to_string_free(c_str.as_ptr());
    }
}
116
117#[cfg(not(feature = "llvm_enzyme"))]
118pub(crate) fn add_tt<'ll>(
119    _llmod: &'ll llvm::Module,
120    _llcx: &'ll llvm::Context,
121    _fn_def: &'ll Value,
122    _tt: FncTree,
123) {
124    unimplemented!()
125}