rustc_codegen_llvm/
typetree.rs

1use rustc_ast::expand::typetree::FncTree;
2#[cfg(feature = "llvm_enzyme")]
3use {
4    crate::attributes,
5    rustc_ast::expand::typetree::TypeTree as RustTypeTree,
6    std::ffi::{CString, c_char, c_uint},
7};
8
9use crate::llvm::{self, Value};
10
/// Converts a rustc-side type tree into a freshly allocated Enzyme `TypeTree`.
///
/// `_data_layout` is accepted for future use but is not consulted today.
#[cfg(feature = "llvm_enzyme")]
fn to_enzyme_typetree(
    rust_typetree: RustTypeTree,
    _data_layout: &str,
    llcx: &llvm::Context,
) -> llvm::TypeTree {
    let mut tree = llvm::TypeTree::new();
    // Top-level entries have no enclosing pointer, so their index path starts empty.
    let root_path: &[i64] = &[];
    process_typetree_recursive(&mut tree, &rust_typetree, root_path, llcx);
    tree
}
/// Inserts every entry of `rust_typetree` into `enzyme_tt`.
///
/// Each entry is recorded under an index path made of `parent_indices` followed by the
/// entry's own `offset`, with its `Kind` mapped one-to-one onto an Enzyme concrete type.
/// Entries of kind `Pointer` that have a non-empty `child` tree are descended into, so
/// their children are inserted under the pointer's own path.
#[cfg(feature = "llvm_enzyme")]
fn process_typetree_recursive(
    enzyme_tt: &mut llvm::TypeTree,
    rust_typetree: &RustTypeTree,
    parent_indices: &[i64],
    llcx: &llvm::Context,
) {
    use rustc_ast::expand::typetree::Kind;

    for rust_type in &rust_typetree.0 {
        let concrete_type = match rust_type.kind {
            Kind::Anything => llvm::CConcreteType::DT_Anything,
            Kind::Integer => llvm::CConcreteType::DT_Integer,
            Kind::Pointer => llvm::CConcreteType::DT_Pointer,
            Kind::Half => llvm::CConcreteType::DT_Half,
            Kind::Float => llvm::CConcreteType::DT_Float,
            Kind::Double => llvm::CConcreteType::DT_Double,
            Kind::F128 => llvm::CConcreteType::DT_FP128,
            Kind::Unknown => llvm::CConcreteType::DT_Unknown,
        };

        // Path = parent path + this entry's offset. The offset is forwarded verbatim for
        // every entry, including `-1`. NOTE(review): `-1` is presumably Enzyme's wildcard
        // "any offset" marker — confirm. `isize -> i64` is lossless where isize <= 64 bits.
        let mut indices = Vec::with_capacity(parent_indices.len() + 1);
        indices.extend_from_slice(parent_indices);
        indices.push(rust_type.offset as i64);

        enzyme_tt.insert(&indices, concrete_type, llcx);

        // Only pointers carry a nested layout worth describing.
        if matches!(rust_type.kind, Kind::Pointer) && !rust_type.child.0.is_empty() {
            process_typetree_recursive(enzyme_tt, &rust_type.child, &indices, llcx);
        }
    }
}
58
/// Attaches Enzyme type information to `fn_def`.
///
/// Each argument type tree in `tt.args` is attached to the argument at the same position.
/// The return type tree `tt.ret` is attached to the return value. Each tree is lowered to an
/// Enzyme type tree, rendered to a string, and attached as an `enzyme_type` string attribute.
///
/// # Panics
/// Panics if LLVM reports a data-layout string that is not valid UTF-8.
#[cfg(feature = "llvm_enzyme")]
pub(crate) fn add_tt<'ll>(
    llmod: &'ll llvm::Module,
    llcx: &'ll llvm::Context,
    fn_def: &'ll Value,
    tt: FncTree,
) {
    let inputs = tt.args;
    let ret_tt: RustTypeTree = tt.ret;

    // SAFETY: `llmod` is a live module reference. The returned pointer is read as a
    // NUL-terminated C string, which is what the LLVM C API hands back for the data layout.
    let llvm_data_layout: *const c_char = unsafe { llvm::LLVMGetDataLayoutStr(llmod) };
    // The layout string is currently ignored by `to_enzyme_typetree` (its `_data_layout`
    // parameter), but it is still validated here.
    let llvm_data_layout = unsafe { std::ffi::CStr::from_ptr(llvm_data_layout) }
        .to_str()
        .expect("got a non-UTF8 data-layout from LLVM");

    let c_attr_name = CString::new("enzyme_type").expect("attribute name has no interior NUL");

    // The argument trees are consumed rather than cloned. `0u32..` produces the argument
    // index directly in the type `AttributePlace::Argument` takes, avoiding a `usize as u32`.
    for (idx, input) in (0u32..).zip(inputs) {
        attach_enzyme_type(
            llcx,
            fn_def,
            llvm::AttributePlace::Argument(idx),
            &c_attr_name,
            input,
            llvm_data_layout,
        );
    }

    attach_enzyme_type(
        llcx,
        fn_def,
        llvm::AttributePlace::ReturnValue,
        &c_attr_name,
        ret_tt,
        llvm_data_layout,
    );
}

/// Lowers `rust_tt` to an Enzyme type tree and attaches its string form to `fn_def` at
/// `place`, as a string attribute named `attr_name`.
#[cfg(feature = "llvm_enzyme")]
fn attach_enzyme_type(
    llcx: &llvm::Context,
    fn_def: &Value,
    place: llvm::AttributePlace,
    attr_name: &std::ffi::CStr,
    rust_tt: RustTypeTree,
    data_layout: &str,
) {
    let enzyme_tt = to_enzyme_typetree(rust_tt, data_layout, llcx);
    // SAFETY: assumes `EnzymeTypeTreeToString` returns a valid NUL-terminated string.
    // That string is used only before the single `EnzymeTypeTreeToStringFree` call that
    // releases it. `enzyme_tt` remains alive for the whole block.
    unsafe {
        let value = std::ffi::CStr::from_ptr(llvm::EnzymeTypeTreeToString(enzyme_tt.inner));
        let attr = llvm::LLVMCreateStringAttribute(
            llcx,
            attr_name.as_ptr(),
            attr_name.to_bytes().len() as c_uint,
            value.as_ptr(),
            value.to_bytes().len() as c_uint,
        );
        attributes::apply_to_llfn(fn_def, place, &[attr]);
        llvm::EnzymeTypeTreeToStringFree(value.as_ptr());
    }
}
113
114#[cfg(not(feature = "llvm_enzyme"))]
115pub(crate) fn add_tt<'ll>(
116    _llmod: &'ll llvm::Module,
117    _llcx: &'ll llvm::Context,
118    _fn_def: &'ll Value,
119    _tt: FncTree,
120) {
121    unimplemented!()
122}