// rustc_codegen_llvm/llvm/enzyme_ffi.rs

1#![expect(dead_code)]
2
3use libc::{c_char, c_uint};
4
5use super::MetadataKindId;
6use super::ffi::{AttributeKind, BasicBlock, Context, Metadata, Module, Type, Value};
7use crate::llvm::{Bool, Builder};
8
9// TypeTree types
10pub(crate) type CTypeTreeRef = *mut EnzymeTypeTree;
11
/// Opaque stand-in for Enzyme's native type-tree object.
///
/// The private zero-length field makes the type zero-sized and impossible to
/// construct outside this module; in this file it only ever appears behind a
/// `CTypeTreeRef` pointer.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub(crate) struct EnzymeTypeTree {
    _unused: [u8; 0],
}

/// Rust mirror of Enzyme's C `CConcreteType` enumeration.
///
/// Values are passed by value into Enzyme (`EnzymeNewTypeTreeCT`,
/// `EnzymeTypeTreeInsertEq`), so every discriminant must equal the C definition.
#[repr(u32)]
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
#[allow(non_camel_case_types)]
pub(crate) enum CConcreteType {
    DT_Anything = 0,
    DT_Integer = 1,
    DT_Pointer = 2,
    DT_Half = 3,
    DT_Float = 4,
    DT_Double = 5,
    DT_Unknown = 6,
    // Discriminants 7 and 8 are not mirrored on the Rust side; the explicit value
    // below keeps DT_FP128 aligned with its C counterpart.
    DT_FP128 = 9,
}

32pub(crate) struct TypeTree {
33    pub(crate) inner: CTypeTreeRef,
34}
35
36#[link(name = "llvm-wrapper", kind = "static")]
37unsafe extern "C" {
38    // Enzyme
39    pub(crate) safe fn LLVMRustHasMetadata(I: &Value, KindID: MetadataKindId) -> bool;
40    pub(crate) fn LLVMRustEraseInstUntilInclusive(BB: &BasicBlock, I: &Value);
41    pub(crate) fn LLVMRustGetLastInstruction<'a>(BB: &BasicBlock) -> Option<&'a Value>;
42    pub(crate) fn LLVMRustDIGetInstMetadata(I: &Value) -> Option<&Metadata>;
43    pub(crate) fn LLVMRustEraseInstFromParent(V: &Value);
44    pub(crate) fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value;
45    pub(crate) fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
46    pub(crate) fn LLVMRustHasAttributeAtIndex(V: &Value, i: c_uint, Kind: AttributeKind) -> bool;
47    pub(crate) fn LLVMRustGetArrayNumElements(Ty: &Type) -> u64;
48    pub(crate) fn LLVMRustHasFnAttribute(
49        F: &Value,
50        Name: *const c_char,
51        NameLen: libc::size_t,
52    ) -> bool;
53    pub(crate) fn LLVMRustRemoveFnAttribute(F: &Value, Name: *const c_char, NameLen: libc::size_t);
54    pub(crate) fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>;
55    pub(crate) fn LLVMGetNextFunction(Fn: &Value) -> Option<&Value>;
56    pub(crate) fn LLVMRustRemoveEnumAttributeAtIndex(
57        Fn: &Value,
58        index: c_uint,
59        kind: AttributeKind,
60    );
61    pub(crate) fn LLVMRustPositionBefore<'a>(B: &'a Builder<'_>, I: &'a Value);
62    pub(crate) fn LLVMRustPositionAfter<'a>(B: &'a Builder<'_>, I: &'a Value);
63    pub(crate) fn LLVMRustGetFunctionCall(
64        F: &Value,
65        name: *const c_char,
66        NameLen: libc::size_t,
67    ) -> Option<&Value>;
68
69}
70
71unsafe extern "C" {
72    // Enzyme
73    pub(crate) fn LLVMDumpModule(M: &Module);
74    pub(crate) fn LLVMDumpValue(V: &Value);
75    pub(crate) fn LLVMGetFunctionCallConv(F: &Value) -> c_uint;
76    pub(crate) fn LLVMGetReturnType(T: &Type) -> &Type;
77    pub(crate) fn LLVMGetParams(Fnc: &Value, params: *mut &Value);
78    pub(crate) fn LLVMGetNamedFunction(M: &Module, Name: *const c_char) -> Option<&Value>;
79}
80
/// Failure policy handed to `LLVMRustVerifyFunction`.
///
/// The discriminants appear to mirror LLVM's C `LLVMVerifierFailureAction`
/// (per LLVM's docs: abort the process, print a message and return, or just
/// return a status). Keep them in sync with that enum.
#[repr(C)]
#[derive(Clone, Copy, PartialEq)]
pub(crate) enum LLVMRustVerifierFailureAction {
    LLVMAbortProcessAction = 0,
    LLVMPrintMessageAction = 1,
    LLVMReturnStatusAction = 2,
}

89pub(crate) use self::Enzyme_AD::*;
90
91pub(crate) mod Enzyme_AD {
92    use std::ffi::{c_char, c_void};
93    use std::sync::{Mutex, MutexGuard, OnceLock};
94
95    use rustc_middle::bug;
96    use rustc_session::config::{Sysroot, host_tuple};
97    use rustc_session::filesearch;
98
99    use super::{CConcreteType, CTypeTreeRef, Context};
100    use crate::llvm::{EnzymeTypeTree, LLVMRustVersionMajor};
101
102    type EnzymeSetCLBoolFn = unsafe extern "C" fn(*mut c_void, u8);
103    type EnzymeSetCLStringFn = unsafe extern "C" fn(*mut c_void, *const c_char);
104
105    type EnzymeNewTypeTreeFn = unsafe extern "C" fn() -> CTypeTreeRef;
106    type EnzymeNewTypeTreeCTFn = unsafe extern "C" fn(CConcreteType, &Context) -> CTypeTreeRef;
107    type EnzymeNewTypeTreeTRFn = unsafe extern "C" fn(CTypeTreeRef) -> CTypeTreeRef;
108    type EnzymeFreeTypeTreeFn = unsafe extern "C" fn(CTypeTreeRef);
109    type EnzymeMergeTypeTreeFn = unsafe extern "C" fn(CTypeTreeRef, CTypeTreeRef) -> bool;
110    type EnzymeTypeTreeOnlyEqFn = unsafe extern "C" fn(CTypeTreeRef, i64);
111    type EnzymeTypeTreeData0EqFn = unsafe extern "C" fn(CTypeTreeRef);
112    type EnzymeTypeTreeShiftIndiciesEqFn =
113        unsafe extern "C" fn(CTypeTreeRef, *const c_char, i64, i64, u64);
114    type EnzymeTypeTreeInsertEqFn =
115        unsafe extern "C" fn(CTypeTreeRef, *const i64, usize, CConcreteType, &Context);
116    type EnzymeTypeTreeToStringFn = unsafe extern "C" fn(CTypeTreeRef) -> *const c_char;
117    type EnzymeTypeTreeToStringFreeFn = unsafe extern "C" fn(*const c_char);
118
119    #[allow(non_snake_case)]
120    pub(crate) struct EnzymeWrapper {
121        EnzymeNewTypeTree: EnzymeNewTypeTreeFn,
122        EnzymeNewTypeTreeCT: EnzymeNewTypeTreeCTFn,
123        EnzymeNewTypeTreeTR: EnzymeNewTypeTreeTRFn,
124        EnzymeFreeTypeTree: EnzymeFreeTypeTreeFn,
125        EnzymeMergeTypeTree: EnzymeMergeTypeTreeFn,
126        EnzymeTypeTreeOnlyEq: EnzymeTypeTreeOnlyEqFn,
127        EnzymeTypeTreeData0Eq: EnzymeTypeTreeData0EqFn,
128        EnzymeTypeTreeShiftIndiciesEq: EnzymeTypeTreeShiftIndiciesEqFn,
129        EnzymeTypeTreeInsertEq: EnzymeTypeTreeInsertEqFn,
130        EnzymeTypeTreeToString: EnzymeTypeTreeToStringFn,
131        EnzymeTypeTreeToStringFree: EnzymeTypeTreeToStringFreeFn,
132
133        EnzymePrintPerf: *mut c_void,
134        EnzymePrintActivity: *mut c_void,
135        EnzymePrintType: *mut c_void,
136        EnzymeFunctionToAnalyze: *mut c_void,
137        EnzymePrint: *mut c_void,
138        EnzymeStrictAliasing: *mut c_void,
139        EnzymeInline: *mut c_void,
140        EnzymeMaxTypeDepth: *mut c_void,
141        RustTypeRules: *mut c_void,
142        looseTypeAnalysis: *mut c_void,
143
144        EnzymeSetCLBool: EnzymeSetCLBoolFn,
145        EnzymeSetCLString: EnzymeSetCLStringFn,
146        pub registerEnzymeAndPassPipeline: *const c_void,
147        lib: libloading::Library,
148    }
149
150    unsafe impl Sync for EnzymeWrapper {}
151    unsafe impl Send for EnzymeWrapper {}
152
153    fn load_ptr_by_symbol_mut_void(
154        lib: &libloading::Library,
155        bytes: &[u8],
156    ) -> Result<*mut c_void, Box<dyn std::error::Error>> {
157        unsafe {
158            let s: libloading::Symbol<'_, *mut c_void> = lib.get(bytes)?;
159            // libloading = 0.9.0: try_as_raw_ptr always succeeds and returns Some
160            let s = s.try_as_raw_ptr().unwrap();
161            Ok(s)
162        }
163    }
164
165    // e.g.
166    // load_ptrs_by_symbols_mut_void(ABC, XYZ);
167    // =>
168    // let ABC = load_ptr_mut_void(&lib, b"ABC")?;
169    // let XYZ = load_ptr_mut_void(&lib, b"XYZ")?;
170    macro_rules! load_ptrs_by_symbols_mut_void {
171        ($lib:expr, $($name:ident),* $(,)?) => {
172            $(
173                #[allow(non_snake_case)]
174                let $name = load_ptr_by_symbol_mut_void(&$lib, stringify!($name).as_bytes())?;
175            )*
176        };
177    }
178
179    // e.g.
180    // load_ptrs_by_symbols_fn(ABC: ABCFn, XYZ: XYZFn);
181    // =>
182    // let ABC: libloading::Symbol<'_, ABCFn> = unsafe { lib.get(b"ABC")? };
183    // let XYZ: libloading::Symbol<'_, XYZFn> = unsafe { lib.get(b"XYZ")? };
184    macro_rules! load_ptrs_by_symbols_fn {
185        ($lib:expr, $($name:ident : $ty:ty),* $(,)?) => {
186            $(
187                #[allow(non_snake_case)]
188                let $name: $ty = *unsafe { $lib.get::<$ty>(stringify!($name).as_bytes())? };
189            )*
190        };
191    }
192
193    static ENZYME_INSTANCE: OnceLock<Mutex<EnzymeWrapper>> = OnceLock::new();
194
195    impl EnzymeWrapper {
196        /// Initialize EnzymeWrapper with the given sysroot if not already initialized.
197        /// Safe to call multiple times - subsequent calls are no-ops due to OnceLock.
198        pub(crate) fn get_or_init(
199            sysroot: &rustc_session::config::Sysroot,
200        ) -> Result<MutexGuard<'static, Self>, Box<dyn std::error::Error>> {
201            let mtx: &'static Mutex<EnzymeWrapper> = ENZYME_INSTANCE.get_or_try_init(|| {
202                let w = Self::call_dynamic(sysroot)?;
203                Ok::<_, Box<dyn std::error::Error>>(Mutex::new(w))
204            })?;
205
206            Ok(mtx.lock().unwrap())
207        }
208
209        /// Get the EnzymeWrapper instance. Panics if not initialized.
210        pub(crate) fn get_instance() -> MutexGuard<'static, Self> {
211            ENZYME_INSTANCE
212                .get()
213                .expect("EnzymeWrapper not initialized. Call get_or_init with sysroot first.")
214                .lock()
215                .unwrap()
216        }
217
218        pub(crate) fn new_type_tree(&self) -> CTypeTreeRef {
219            unsafe { (self.EnzymeNewTypeTree)() }
220        }
221
222        pub(crate) fn new_type_tree_ct(
223            &self,
224            t: CConcreteType,
225            ctx: &Context,
226        ) -> *mut EnzymeTypeTree {
227            unsafe { (self.EnzymeNewTypeTreeCT)(t, ctx) }
228        }
229
230        pub(crate) fn new_type_tree_tr(&self, tree: CTypeTreeRef) -> CTypeTreeRef {
231            unsafe { (self.EnzymeNewTypeTreeTR)(tree) }
232        }
233
234        pub(crate) fn free_type_tree(&self, tree: CTypeTreeRef) {
235            unsafe { (self.EnzymeFreeTypeTree)(tree) }
236        }
237
238        pub(crate) fn merge_type_tree(&self, tree1: CTypeTreeRef, tree2: CTypeTreeRef) -> bool {
239            unsafe { (self.EnzymeMergeTypeTree)(tree1, tree2) }
240        }
241
242        pub(crate) fn tree_only_eq(&self, tree: CTypeTreeRef, num: i64) {
243            unsafe { (self.EnzymeTypeTreeOnlyEq)(tree, num) }
244        }
245
246        pub(crate) fn tree_data0_eq(&self, tree: CTypeTreeRef) {
247            unsafe { (self.EnzymeTypeTreeData0Eq)(tree) }
248        }
249
250        pub(crate) fn shift_indicies_eq(
251            &self,
252            tree: CTypeTreeRef,
253            data_layout: *const c_char,
254            offset: i64,
255            max_size: i64,
256            add_offset: u64,
257        ) {
258            unsafe {
259                (self.EnzymeTypeTreeShiftIndiciesEq)(
260                    tree,
261                    data_layout,
262                    offset,
263                    max_size,
264                    add_offset,
265                )
266            }
267        }
268
269        pub(crate) fn tree_insert_eq(
270            &self,
271            tree: CTypeTreeRef,
272            indices: *const i64,
273            len: usize,
274            ct: CConcreteType,
275            ctx: &Context,
276        ) {
277            unsafe { (self.EnzymeTypeTreeInsertEq)(tree, indices, len, ct, ctx) }
278        }
279
280        pub(crate) fn tree_to_string(&self, tree: *mut EnzymeTypeTree) -> *const c_char {
281            unsafe { (self.EnzymeTypeTreeToString)(tree) }
282        }
283
284        pub(crate) fn tree_to_string_free(&self, ch: *const c_char) {
285            unsafe { (self.EnzymeTypeTreeToStringFree)(ch) }
286        }
287
288        pub(crate) fn get_max_type_depth(&self) -> usize {
289            unsafe { std::ptr::read::<u32>(self.EnzymeMaxTypeDepth as *const u32) as usize }
290        }
291
292        pub(crate) fn set_print_perf(&mut self, print: bool) {
293            unsafe {
294                (self.EnzymeSetCLBool)(self.EnzymePrintPerf, print as u8);
295            }
296        }
297
298        pub(crate) fn set_print_activity(&mut self, print: bool) {
299            unsafe {
300                (self.EnzymeSetCLBool)(self.EnzymePrintActivity, print as u8);
301            }
302        }
303
304        pub(crate) fn set_print_type(&mut self, print: bool) {
305            unsafe {
306                (self.EnzymeSetCLBool)(self.EnzymePrintType, print as u8);
307            }
308        }
309
310        pub(crate) fn set_print_type_fun(&mut self, fun_name: &str) {
311            let c_fun_name = std::ffi::CString::new(fun_name)
312                .unwrap_or_else(|err| bug!("failed to set_print_type_fun: {err}"));
313            unsafe {
314                (self.EnzymeSetCLString)(
315                    self.EnzymeFunctionToAnalyze,
316                    c_fun_name.as_ptr() as *const c_char,
317                );
318            }
319        }
320
321        pub(crate) fn set_print(&mut self, print: bool) {
322            unsafe {
323                (self.EnzymeSetCLBool)(self.EnzymePrint, print as u8);
324            }
325        }
326
327        pub(crate) fn set_strict_aliasing(&mut self, strict: bool) {
328            unsafe {
329                (self.EnzymeSetCLBool)(self.EnzymeStrictAliasing, strict as u8);
330            }
331        }
332
333        pub(crate) fn set_loose_types(&mut self, loose: bool) {
334            unsafe {
335                (self.EnzymeSetCLBool)(self.looseTypeAnalysis, loose as u8);
336            }
337        }
338
339        pub(crate) fn set_inline(&mut self, val: bool) {
340            unsafe {
341                (self.EnzymeSetCLBool)(self.EnzymeInline, val as u8);
342            }
343        }
344
345        pub(crate) fn set_rust_rules(&mut self, val: bool) {
346            unsafe {
347                (self.EnzymeSetCLBool)(self.RustTypeRules, val as u8);
348            }
349        }
350
351        #[allow(non_snake_case)]
352        fn call_dynamic(
353            sysroot: &rustc_session::config::Sysroot,
354        ) -> Result<Self, Box<dyn std::error::Error>> {
355            let enzyme_path = Self::get_enzyme_path(sysroot)?;
356            let lib = unsafe { libloading::Library::new(enzyme_path)? };
357
358            load_ptrs_by_symbols_fn!(
359                lib,
360                EnzymeNewTypeTree: EnzymeNewTypeTreeFn,
361                EnzymeNewTypeTreeCT: EnzymeNewTypeTreeCTFn,
362                EnzymeNewTypeTreeTR: EnzymeNewTypeTreeTRFn,
363                EnzymeFreeTypeTree: EnzymeFreeTypeTreeFn,
364                EnzymeMergeTypeTree: EnzymeMergeTypeTreeFn,
365                EnzymeTypeTreeOnlyEq: EnzymeTypeTreeOnlyEqFn,
366                EnzymeTypeTreeData0Eq: EnzymeTypeTreeData0EqFn,
367                EnzymeTypeTreeShiftIndiciesEq: EnzymeTypeTreeShiftIndiciesEqFn,
368                EnzymeTypeTreeInsertEq: EnzymeTypeTreeInsertEqFn,
369                EnzymeTypeTreeToString: EnzymeTypeTreeToStringFn,
370                EnzymeTypeTreeToStringFree: EnzymeTypeTreeToStringFreeFn,
371                EnzymeSetCLBool: EnzymeSetCLBoolFn,
372                EnzymeSetCLString: EnzymeSetCLStringFn,
373            );
374
375            load_ptrs_by_symbols_mut_void!(
376                lib,
377                registerEnzymeAndPassPipeline,
378                EnzymePrintPerf,
379                EnzymePrintActivity,
380                EnzymePrintType,
381                EnzymeFunctionToAnalyze,
382                EnzymePrint,
383                EnzymeStrictAliasing,
384                EnzymeInline,
385                EnzymeMaxTypeDepth,
386                RustTypeRules,
387                looseTypeAnalysis,
388            );
389
390            Ok(Self {
391                EnzymeNewTypeTree,
392                EnzymeNewTypeTreeCT,
393                EnzymeNewTypeTreeTR,
394                EnzymeFreeTypeTree,
395                EnzymeMergeTypeTree,
396                EnzymeTypeTreeOnlyEq,
397                EnzymeTypeTreeData0Eq,
398                EnzymeTypeTreeShiftIndiciesEq,
399                EnzymeTypeTreeInsertEq,
400                EnzymeTypeTreeToString,
401                EnzymeTypeTreeToStringFree,
402                EnzymePrintPerf,
403                EnzymePrintActivity,
404                EnzymePrintType,
405                EnzymeFunctionToAnalyze,
406                EnzymePrint,
407                EnzymeStrictAliasing,
408                EnzymeInline,
409                EnzymeMaxTypeDepth,
410                RustTypeRules,
411                looseTypeAnalysis,
412                EnzymeSetCLBool,
413                EnzymeSetCLString,
414                registerEnzymeAndPassPipeline,
415                lib,
416            })
417        }
418
419        fn get_enzyme_path(sysroot: &Sysroot) -> Result<String, String> {
420            let llvm_version_major = unsafe { LLVMRustVersionMajor() };
421
422            let path_buf = sysroot
423                .all_paths()
424                .map(|sysroot_path| {
425                    filesearch::make_target_lib_path(sysroot_path, host_tuple())
426                        .join("lib")
427                        .with_file_name(format!("libEnzyme-{llvm_version_major}"))
428                        .with_extension(std::env::consts::DLL_EXTENSION)
429                })
430                .find(|f| f.exists())
431                .ok_or_else(|| {
432                    let candidates = sysroot
433                        .all_paths()
434                        .map(|p| p.join("lib").display().to_string())
435                        .collect::<Vec<String>>()
436                        .join("\n* ");
437                    format!(
438                        "failed to find a `libEnzyme-{llvm_version_major}` folder \
439                    in the sysroot candidates:\n* {candidates}"
440                    )
441                })?;
442
443            Ok(path_buf
444                .to_str()
445                .ok_or_else(|| format!("invalid UTF-8 in path: {}", path_buf.display()))?
446                .to_string())
447        }
448    }
449}
450
451impl TypeTree {
452    pub(crate) fn new() -> TypeTree {
453        let wrapper = EnzymeWrapper::get_instance();
454        let inner = wrapper.new_type_tree();
455        TypeTree { inner }
456    }
457
458    pub(crate) fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree {
459        let wrapper = EnzymeWrapper::get_instance();
460        let inner = wrapper.new_type_tree_ct(t, ctx);
461        TypeTree { inner }
462    }
463
464    pub(crate) fn merge(self, other: Self) -> Self {
465        let wrapper = EnzymeWrapper::get_instance();
466        wrapper.merge_type_tree(self.inner, other.inner);
467        drop(other);
468        self
469    }
470
471    #[must_use]
472    pub(crate) fn shift(
473        self,
474        layout: &str,
475        offset: isize,
476        max_size: isize,
477        add_offset: usize,
478    ) -> Self {
479        let layout = std::ffi::CString::new(layout).unwrap();
480        let wrapper = EnzymeWrapper::get_instance();
481        wrapper.shift_indicies_eq(
482            self.inner,
483            layout.as_ptr(),
484            offset as i64,
485            max_size as i64,
486            add_offset as u64,
487        );
488
489        self
490    }
491
492    pub(crate) fn insert(&mut self, indices: &[i64], ct: CConcreteType, ctx: &Context) {
493        let wrapper = EnzymeWrapper::get_instance();
494        wrapper.tree_insert_eq(self.inner, indices.as_ptr(), indices.len(), ct, ctx);
495    }
496}
497
498impl Clone for TypeTree {
499    fn clone(&self) -> Self {
500        let wrapper = EnzymeWrapper::get_instance();
501        let inner = wrapper.new_type_tree_tr(self.inner);
502        TypeTree { inner }
503    }
504}
505
506impl std::fmt::Display for TypeTree {
507    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
508        let wrapper = EnzymeWrapper::get_instance();
509        let ptr = wrapper.tree_to_string(self.inner);
510        let cstr = unsafe { std::ffi::CStr::from_ptr(ptr) };
511        match cstr.to_str() {
512            Ok(x) => write!(f, "{}", x)?,
513            Err(err) => write!(f, "could not parse: {}", err)?,
514        }
515
516        // delete C string pointer
517        wrapper.tree_to_string_free(ptr);
518
519        Ok(())
520    }
521}
522
523impl std::fmt::Debug for TypeTree {
524    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
525        <Self as std::fmt::Display>::fmt(self, f)
526    }
527}
528
529impl Drop for TypeTree {
530    fn drop(&mut self) {
531        let wrapper = EnzymeWrapper::get_instance();
532        wrapper.free_type_tree(self.inner)
533    }
534}