rustc_codegen_llvm/llvm/
enzyme_ffi.rs

1#![expect(dead_code)]
2
3use libc::{c_char, c_uint};
4
5use super::MetadataKindId;
6use super::ffi::{AttributeKind, BasicBlock, Context, Metadata, Module, Type, Value};
7use crate::llvm::{Bool, Builder};
8
// TypeTree types

/// Opaque stand-in for the type-tree object owned by Enzyme.
///
/// Its only field is a private zero-length array, so the struct has size zero and is only
/// meant to be used behind a pointer (see [`CTypeTreeRef`]).
#[derive(Clone, Copy, Debug)]
#[repr(C)]
pub(crate) struct EnzymeTypeTree {
    _unused: [u8; 0],
}

/// Raw mutable pointer to an [`EnzymeTypeTree`]; the handle type passed to the Enzyme
/// type-tree functions.
pub(crate) type CTypeTreeRef = *mut EnzymeTypeTree;
17
/// Concrete type tags that can be stored in a type tree.
///
/// The enum is `#[repr(u32)]` with explicit discriminants, so each variant has a fixed
/// numeric value when passed by value over FFI. The values 7 and 8 have no variant here.
/// NOTE(review): the numbering presumably mirrors the enum of the same name in Enzyme's C
/// API — confirm before adding or renumbering variants.
#[allow(non_camel_case_types)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[repr(u32)]
pub(crate) enum CConcreteType {
    DT_Anything = 0,
    DT_Integer = 1,
    DT_Pointer = 2,
    DT_Half = 3,
    DT_Float = 4,
    DT_Double = 5,
    DT_Unknown = 6,
    DT_FP128 = 9,
}
31
32pub(crate) struct TypeTree {
33    pub(crate) inner: CTypeTreeRef,
34}
35
36#[link(name = "llvm-wrapper", kind = "static")]
37unsafe extern "C" {
38    // Enzyme
39    pub(crate) safe fn LLVMRustHasMetadata(I: &Value, KindID: MetadataKindId) -> bool;
40    pub(crate) fn LLVMRustEraseInstUntilInclusive(BB: &BasicBlock, I: &Value);
41    pub(crate) fn LLVMRustGetLastInstruction<'a>(BB: &BasicBlock) -> Option<&'a Value>;
42    pub(crate) fn LLVMRustDIGetInstMetadata(I: &Value) -> Option<&Metadata>;
43    pub(crate) fn LLVMRustEraseInstFromParent(V: &Value);
44    pub(crate) fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value;
45    pub(crate) fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
46    pub(crate) fn LLVMRustHasAttributeAtIndex(V: &Value, i: c_uint, Kind: AttributeKind) -> bool;
47    pub(crate) fn LLVMRustGetArrayNumElements(Ty: &Type) -> u64;
48    pub(crate) fn LLVMRustHasFnAttribute(
49        F: &Value,
50        Name: *const c_char,
51        NameLen: libc::size_t,
52    ) -> bool;
53    pub(crate) fn LLVMRustRemoveFnAttribute(F: &Value, Name: *const c_char, NameLen: libc::size_t);
54    pub(crate) fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>;
55    pub(crate) fn LLVMGetNextFunction(Fn: &Value) -> Option<&Value>;
56    pub(crate) fn LLVMRustRemoveEnumAttributeAtIndex(
57        Fn: &Value,
58        index: c_uint,
59        kind: AttributeKind,
60    );
61    pub(crate) fn LLVMRustPositionBefore<'a>(B: &'a Builder<'_>, I: &'a Value);
62    pub(crate) fn LLVMRustPositionAfter<'a>(B: &'a Builder<'_>, I: &'a Value);
63    pub(crate) fn LLVMRustGetFunctionCall(
64        F: &Value,
65        name: *const c_char,
66        NameLen: libc::size_t,
67    ) -> Option<&Value>;
68
69}
70
71unsafe extern "C" {
72    // Enzyme
73    pub(crate) fn LLVMDumpModule(M: &Module);
74    pub(crate) fn LLVMDumpValue(V: &Value);
75    pub(crate) fn LLVMGetFunctionCallConv(F: &Value) -> c_uint;
76    pub(crate) fn LLVMGetReturnType(T: &Type) -> &Type;
77    pub(crate) fn LLVMGetParams(Fnc: &Value, params: *mut &Value);
78    pub(crate) fn LLVMGetNamedFunction(M: &Module, Name: *const c_char) -> Option<&Value>;
79}
80
/// Action selector passed to `LLVMRustVerifyFunction`.
///
/// `#[repr(C)]` with explicit discriminants 0, 1 and 2.
/// NOTE(review): the variant names suggest abort / print / return-status behaviour on
/// verifier failure — confirm against the LLVM-C `LLVMVerifierFailureAction` documentation.
#[derive(Clone, Copy, PartialEq)]
#[repr(C)]
pub(crate) enum LLVMRustVerifierFailureAction {
    LLVMAbortProcessAction = 0,
    LLVMPrintMessageAction = 1,
    LLVMReturnStatusAction = 2,
}
88
89pub(crate) use self::Enzyme_AD::*;
90
91pub(crate) mod Enzyme_AD {
92    use std::ffi::{c_char, c_void};
93    use std::sync::{Mutex, MutexGuard, OnceLock};
94
95    use rustc_middle::bug;
96    use rustc_session::config::{Sysroot, host_tuple};
97    use rustc_session::filesearch;
98
99    use super::{CConcreteType, CTypeTreeRef, Context};
100    use crate::llvm::{EnzymeTypeTree, LLVMRustVersionMajor};
101
102    type EnzymeSetCLBoolFn = unsafe extern "C" fn(*mut c_void, u8);
103    type EnzymeSetCLStringFn = unsafe extern "C" fn(*mut c_void, *const c_char);
104
105    type EnzymeNewTypeTreeFn = unsafe extern "C" fn() -> CTypeTreeRef;
106    type EnzymeNewTypeTreeCTFn = unsafe extern "C" fn(CConcreteType, &Context) -> CTypeTreeRef;
107    type EnzymeNewTypeTreeTRFn = unsafe extern "C" fn(CTypeTreeRef) -> CTypeTreeRef;
108    type EnzymeFreeTypeTreeFn = unsafe extern "C" fn(CTypeTreeRef);
109    type EnzymeMergeTypeTreeFn = unsafe extern "C" fn(CTypeTreeRef, CTypeTreeRef) -> bool;
110    type EnzymeTypeTreeOnlyEqFn = unsafe extern "C" fn(CTypeTreeRef, i64);
111    type EnzymeTypeTreeData0EqFn = unsafe extern "C" fn(CTypeTreeRef);
112    type EnzymeTypeTreeShiftIndiciesEqFn =
113        unsafe extern "C" fn(CTypeTreeRef, *const c_char, i64, i64, u64);
114    type EnzymeTypeTreeInsertEqFn =
115        unsafe extern "C" fn(CTypeTreeRef, *const i64, usize, CConcreteType, &Context);
116    type EnzymeTypeTreeToStringFn = unsafe extern "C" fn(CTypeTreeRef) -> *const c_char;
117    type EnzymeTypeTreeToStringFreeFn = unsafe extern "C" fn(*const c_char);
118
119    #[allow(non_snake_case)]
120    pub(crate) struct EnzymeWrapper {
121        EnzymeNewTypeTree: EnzymeNewTypeTreeFn,
122        EnzymeNewTypeTreeCT: EnzymeNewTypeTreeCTFn,
123        EnzymeNewTypeTreeTR: EnzymeNewTypeTreeTRFn,
124        EnzymeFreeTypeTree: EnzymeFreeTypeTreeFn,
125        EnzymeMergeTypeTree: EnzymeMergeTypeTreeFn,
126        EnzymeTypeTreeOnlyEq: EnzymeTypeTreeOnlyEqFn,
127        EnzymeTypeTreeData0Eq: EnzymeTypeTreeData0EqFn,
128        EnzymeTypeTreeShiftIndiciesEq: EnzymeTypeTreeShiftIndiciesEqFn,
129        EnzymeTypeTreeInsertEq: EnzymeTypeTreeInsertEqFn,
130        EnzymeTypeTreeToString: EnzymeTypeTreeToStringFn,
131        EnzymeTypeTreeToStringFree: EnzymeTypeTreeToStringFreeFn,
132
133        EnzymePrintPerf: *mut c_void,
134        EnzymePrintActivity: *mut c_void,
135        EnzymePrintType: *mut c_void,
136        EnzymeFunctionToAnalyze: *mut c_void,
137        EnzymePrint: *mut c_void,
138        EnzymeStrictAliasing: *mut c_void,
139        EnzymeInline: *mut c_void,
140        EnzymeMaxTypeDepth: *mut c_void,
141        RustTypeRules: *mut c_void,
142        looseTypeAnalysis: *mut c_void,
143
144        EnzymeSetCLBool: EnzymeSetCLBoolFn,
145        EnzymeSetCLString: EnzymeSetCLStringFn,
146        pub registerEnzymeAndPassPipeline: *const c_void,
147        lib: libloading::Library,
148    }
149
150    unsafe impl Sync for EnzymeWrapper {}
151    unsafe impl Send for EnzymeWrapper {}
152
153    fn load_ptr_by_symbol_mut_void(
154        lib: &libloading::Library,
155        bytes: &[u8],
156    ) -> Result<*mut c_void, libloading::Error> {
157        unsafe {
158            let s: libloading::Symbol<'_, *mut c_void> = lib.get(bytes)?;
159            // libloading = 0.9.0: try_as_raw_ptr always succeeds and returns Some
160            let s = s.try_as_raw_ptr().unwrap();
161            Ok(s)
162        }
163    }
164
165    // e.g.
166    // load_ptrs_by_symbols_mut_void(ABC, XYZ);
167    // =>
168    // let ABC = load_ptr_mut_void(&lib, b"ABC")?;
169    // let XYZ = load_ptr_mut_void(&lib, b"XYZ")?;
170    macro_rules! load_ptrs_by_symbols_mut_void {
171        ($lib:expr, $($name:ident),* $(,)?) => {
172            $(
173                #[allow(non_snake_case)]
174                let $name = load_ptr_by_symbol_mut_void(&$lib, stringify!($name).as_bytes())?;
175            )*
176        };
177    }
178
179    // e.g.
180    // load_ptrs_by_symbols_fn(ABC: ABCFn, XYZ: XYZFn);
181    // =>
182    // let ABC: libloading::Symbol<'_, ABCFn> = unsafe { lib.get(b"ABC")? };
183    // let XYZ: libloading::Symbol<'_, XYZFn> = unsafe { lib.get(b"XYZ")? };
184    macro_rules! load_ptrs_by_symbols_fn {
185        ($lib:expr, $($name:ident : $ty:ty),* $(,)?) => {
186            $(
187                #[allow(non_snake_case)]
188                let $name: $ty = *unsafe { $lib.get::<$ty>(stringify!($name).as_bytes())? };
189            )*
190        };
191    }
192
193    static ENZYME_INSTANCE: OnceLock<Mutex<EnzymeWrapper>> = OnceLock::new();
194
195    #[derive(Debug)]
196    pub(crate) enum EnzymeLibraryError {
197        NotFound { err: String },
198        LoadFailed { err: String },
199    }
200
201    impl From<libloading::Error> for EnzymeLibraryError {
202        fn from(err: libloading::Error) -> Self {
203            Self::LoadFailed { err: format!("{err:?}") }
204        }
205    }
206
207    impl EnzymeWrapper {
208        /// Initialize EnzymeWrapper with the given sysroot if not already initialized.
209        /// Safe to call multiple times - subsequent calls are no-ops due to OnceLock.
210        pub(crate) fn get_or_init(
211            sysroot: &rustc_session::config::Sysroot,
212        ) -> Result<MutexGuard<'static, Self>, EnzymeLibraryError> {
213            let mtx: &'static Mutex<EnzymeWrapper> = ENZYME_INSTANCE.get_or_try_init(|| {
214                let w = Self::call_dynamic(sysroot)?;
215                Ok::<_, EnzymeLibraryError>(Mutex::new(w))
216            })?;
217
218            Ok(mtx.lock().unwrap())
219        }
220
221        /// Get the EnzymeWrapper instance. Panics if not initialized.
222        pub(crate) fn get_instance() -> MutexGuard<'static, Self> {
223            ENZYME_INSTANCE
224                .get()
225                .expect("EnzymeWrapper not initialized. Call get_or_init with sysroot first.")
226                .lock()
227                .unwrap()
228        }
229
230        pub(crate) fn new_type_tree(&self) -> CTypeTreeRef {
231            unsafe { (self.EnzymeNewTypeTree)() }
232        }
233
234        pub(crate) fn new_type_tree_ct(
235            &self,
236            t: CConcreteType,
237            ctx: &Context,
238        ) -> *mut EnzymeTypeTree {
239            unsafe { (self.EnzymeNewTypeTreeCT)(t, ctx) }
240        }
241
242        pub(crate) fn new_type_tree_tr(&self, tree: CTypeTreeRef) -> CTypeTreeRef {
243            unsafe { (self.EnzymeNewTypeTreeTR)(tree) }
244        }
245
246        pub(crate) fn free_type_tree(&self, tree: CTypeTreeRef) {
247            unsafe { (self.EnzymeFreeTypeTree)(tree) }
248        }
249
250        pub(crate) fn merge_type_tree(&self, tree1: CTypeTreeRef, tree2: CTypeTreeRef) -> bool {
251            unsafe { (self.EnzymeMergeTypeTree)(tree1, tree2) }
252        }
253
254        pub(crate) fn tree_only_eq(&self, tree: CTypeTreeRef, num: i64) {
255            unsafe { (self.EnzymeTypeTreeOnlyEq)(tree, num) }
256        }
257
258        pub(crate) fn tree_data0_eq(&self, tree: CTypeTreeRef) {
259            unsafe { (self.EnzymeTypeTreeData0Eq)(tree) }
260        }
261
262        pub(crate) fn shift_indicies_eq(
263            &self,
264            tree: CTypeTreeRef,
265            data_layout: *const c_char,
266            offset: i64,
267            max_size: i64,
268            add_offset: u64,
269        ) {
270            unsafe {
271                (self.EnzymeTypeTreeShiftIndiciesEq)(
272                    tree,
273                    data_layout,
274                    offset,
275                    max_size,
276                    add_offset,
277                )
278            }
279        }
280
281        pub(crate) fn tree_insert_eq(
282            &self,
283            tree: CTypeTreeRef,
284            indices: *const i64,
285            len: usize,
286            ct: CConcreteType,
287            ctx: &Context,
288        ) {
289            unsafe { (self.EnzymeTypeTreeInsertEq)(tree, indices, len, ct, ctx) }
290        }
291
292        pub(crate) fn tree_to_string(&self, tree: *mut EnzymeTypeTree) -> *const c_char {
293            unsafe { (self.EnzymeTypeTreeToString)(tree) }
294        }
295
296        pub(crate) fn tree_to_string_free(&self, ch: *const c_char) {
297            unsafe { (self.EnzymeTypeTreeToStringFree)(ch) }
298        }
299
300        pub(crate) fn get_max_type_depth(&self) -> usize {
301            unsafe { std::ptr::read::<u32>(self.EnzymeMaxTypeDepth as *const u32) as usize }
302        }
303
304        pub(crate) fn set_print_perf(&mut self, print: bool) {
305            unsafe {
306                (self.EnzymeSetCLBool)(self.EnzymePrintPerf, print as u8);
307            }
308        }
309
310        pub(crate) fn set_print_activity(&mut self, print: bool) {
311            unsafe {
312                (self.EnzymeSetCLBool)(self.EnzymePrintActivity, print as u8);
313            }
314        }
315
316        pub(crate) fn set_print_type(&mut self, print: bool) {
317            unsafe {
318                (self.EnzymeSetCLBool)(self.EnzymePrintType, print as u8);
319            }
320        }
321
322        pub(crate) fn set_print_type_fun(&mut self, fun_name: &str) {
323            let c_fun_name = std::ffi::CString::new(fun_name)
324                .unwrap_or_else(|err| bug!("failed to set_print_type_fun: {err}"));
325            unsafe {
326                (self.EnzymeSetCLString)(
327                    self.EnzymeFunctionToAnalyze,
328                    c_fun_name.as_ptr() as *const c_char,
329                );
330            }
331        }
332
333        pub(crate) fn set_print(&mut self, print: bool) {
334            unsafe {
335                (self.EnzymeSetCLBool)(self.EnzymePrint, print as u8);
336            }
337        }
338
339        pub(crate) fn set_strict_aliasing(&mut self, strict: bool) {
340            unsafe {
341                (self.EnzymeSetCLBool)(self.EnzymeStrictAliasing, strict as u8);
342            }
343        }
344
345        pub(crate) fn set_loose_types(&mut self, loose: bool) {
346            unsafe {
347                (self.EnzymeSetCLBool)(self.looseTypeAnalysis, loose as u8);
348            }
349        }
350
351        pub(crate) fn set_inline(&mut self, val: bool) {
352            unsafe {
353                (self.EnzymeSetCLBool)(self.EnzymeInline, val as u8);
354            }
355        }
356
357        pub(crate) fn set_rust_rules(&mut self, val: bool) {
358            unsafe {
359                (self.EnzymeSetCLBool)(self.RustTypeRules, val as u8);
360            }
361        }
362
363        #[allow(non_snake_case)]
364        fn call_dynamic(
365            sysroot: &rustc_session::config::Sysroot,
366        ) -> Result<Self, EnzymeLibraryError> {
367            let enzyme_path = Self::get_enzyme_path(sysroot)?;
368            let lib = unsafe { libloading::Library::new(enzyme_path)? };
369
370            load_ptrs_by_symbols_fn!(
371                lib,
372                EnzymeNewTypeTree: EnzymeNewTypeTreeFn,
373                EnzymeNewTypeTreeCT: EnzymeNewTypeTreeCTFn,
374                EnzymeNewTypeTreeTR: EnzymeNewTypeTreeTRFn,
375                EnzymeFreeTypeTree: EnzymeFreeTypeTreeFn,
376                EnzymeMergeTypeTree: EnzymeMergeTypeTreeFn,
377                EnzymeTypeTreeOnlyEq: EnzymeTypeTreeOnlyEqFn,
378                EnzymeTypeTreeData0Eq: EnzymeTypeTreeData0EqFn,
379                EnzymeTypeTreeShiftIndiciesEq: EnzymeTypeTreeShiftIndiciesEqFn,
380                EnzymeTypeTreeInsertEq: EnzymeTypeTreeInsertEqFn,
381                EnzymeTypeTreeToString: EnzymeTypeTreeToStringFn,
382                EnzymeTypeTreeToStringFree: EnzymeTypeTreeToStringFreeFn,
383                EnzymeSetCLBool: EnzymeSetCLBoolFn,
384                EnzymeSetCLString: EnzymeSetCLStringFn,
385            );
386
387            load_ptrs_by_symbols_mut_void!(
388                lib,
389                registerEnzymeAndPassPipeline,
390                EnzymePrintPerf,
391                EnzymePrintActivity,
392                EnzymePrintType,
393                EnzymeFunctionToAnalyze,
394                EnzymePrint,
395                EnzymeStrictAliasing,
396                EnzymeInline,
397                EnzymeMaxTypeDepth,
398                RustTypeRules,
399                looseTypeAnalysis,
400            );
401
402            Ok(Self {
403                EnzymeNewTypeTree,
404                EnzymeNewTypeTreeCT,
405                EnzymeNewTypeTreeTR,
406                EnzymeFreeTypeTree,
407                EnzymeMergeTypeTree,
408                EnzymeTypeTreeOnlyEq,
409                EnzymeTypeTreeData0Eq,
410                EnzymeTypeTreeShiftIndiciesEq,
411                EnzymeTypeTreeInsertEq,
412                EnzymeTypeTreeToString,
413                EnzymeTypeTreeToStringFree,
414                EnzymePrintPerf,
415                EnzymePrintActivity,
416                EnzymePrintType,
417                EnzymeFunctionToAnalyze,
418                EnzymePrint,
419                EnzymeStrictAliasing,
420                EnzymeInline,
421                EnzymeMaxTypeDepth,
422                RustTypeRules,
423                looseTypeAnalysis,
424                EnzymeSetCLBool,
425                EnzymeSetCLString,
426                registerEnzymeAndPassPipeline,
427                lib,
428            })
429        }
430
431        fn get_enzyme_path(sysroot: &Sysroot) -> Result<String, EnzymeLibraryError> {
432            let llvm_version_major = unsafe { LLVMRustVersionMajor() };
433
434            let path_buf = sysroot
435                .all_paths()
436                .map(|sysroot_path| {
437                    filesearch::make_target_lib_path(sysroot_path, host_tuple())
438                        .join("lib")
439                        .with_file_name(format!("libEnzyme-{llvm_version_major}"))
440                        .with_extension(std::env::consts::DLL_EXTENSION)
441                })
442                .find(|f| f.exists())
443                .ok_or_else(|| {
444                    let candidates = sysroot
445                        .all_paths()
446                        .map(|p| p.join("lib").display().to_string())
447                        .collect::<Vec<String>>()
448                        .join("\n* ");
449                    EnzymeLibraryError::NotFound {
450                        err: format!(
451                            "failed to find a `libEnzyme-{llvm_version_major}` folder \
452                    in the sysroot candidates:\n* {candidates}"
453                        ),
454                    }
455                })?;
456
457            Ok(path_buf
458                .to_str()
459                .ok_or_else(|| EnzymeLibraryError::LoadFailed {
460                    err: format!("invalid UTF-8 in path: {}", path_buf.display()),
461                })?
462                .to_string())
463        }
464    }
465}
466
467impl TypeTree {
468    pub(crate) fn new() -> TypeTree {
469        let wrapper = EnzymeWrapper::get_instance();
470        let inner = wrapper.new_type_tree();
471        TypeTree { inner }
472    }
473
474    pub(crate) fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree {
475        let wrapper = EnzymeWrapper::get_instance();
476        let inner = wrapper.new_type_tree_ct(t, ctx);
477        TypeTree { inner }
478    }
479
480    pub(crate) fn merge(self, other: Self) -> Self {
481        let wrapper = EnzymeWrapper::get_instance();
482        wrapper.merge_type_tree(self.inner, other.inner);
483        drop(other);
484        self
485    }
486
487    #[must_use]
488    pub(crate) fn shift(
489        self,
490        layout: &str,
491        offset: isize,
492        max_size: isize,
493        add_offset: usize,
494    ) -> Self {
495        let layout = std::ffi::CString::new(layout).unwrap();
496        let wrapper = EnzymeWrapper::get_instance();
497        wrapper.shift_indicies_eq(
498            self.inner,
499            layout.as_ptr(),
500            offset as i64,
501            max_size as i64,
502            add_offset as u64,
503        );
504
505        self
506    }
507
508    pub(crate) fn insert(&mut self, indices: &[i64], ct: CConcreteType, ctx: &Context) {
509        let wrapper = EnzymeWrapper::get_instance();
510        wrapper.tree_insert_eq(self.inner, indices.as_ptr(), indices.len(), ct, ctx);
511    }
512}
513
514impl Clone for TypeTree {
515    fn clone(&self) -> Self {
516        let wrapper = EnzymeWrapper::get_instance();
517        let inner = wrapper.new_type_tree_tr(self.inner);
518        TypeTree { inner }
519    }
520}
521
522impl std::fmt::Display for TypeTree {
523    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
524        let wrapper = EnzymeWrapper::get_instance();
525        let ptr = wrapper.tree_to_string(self.inner);
526        let cstr = unsafe { std::ffi::CStr::from_ptr(ptr) };
527        match cstr.to_str() {
528            Ok(x) => write!(f, "{}", x)?,
529            Err(err) => write!(f, "could not parse: {}", err)?,
530        }
531
532        // delete C string pointer
533        wrapper.tree_to_string_free(ptr);
534
535        Ok(())
536    }
537}
538
539impl std::fmt::Debug for TypeTree {
540    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
541        <Self as std::fmt::Display>::fmt(self, f)
542    }
543}
544
545impl Drop for TypeTree {
546    fn drop(&mut self) {
547        let wrapper = EnzymeWrapper::get_instance();
548        wrapper.free_type_tree(self.inner)
549    }
550}