1#![expect(dead_code)]
2
3use libc::{c_char, c_uint};
4
5use super::MetadataKindId;
6use super::ffi::{AttributeKind, BasicBlock, Context, Metadata, Module, Type, Value};
7use crate::llvm::{Bool, Builder};
8
/// Raw mutable pointer to an opaque [`EnzymeTypeTree`].
///
/// This is the handle type that every type-tree function pointer loaded from
/// the Enzyme library takes or returns in this file.
pub(crate) type CTypeTreeRef = *mut EnzymeTypeTree;
11
/// Opaque placeholder for the type tree object owned by the foreign side.
///
/// The only field is a zero-length array, so the Rust side knows nothing about
/// the real layout; in this file the type is only ever used behind the raw
/// pointer alias [`CTypeTreeRef`].
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub(crate) struct EnzymeTypeTree {
    _unused: [u8; 0],
}
17
/// Concrete type tag passed by value to the Enzyme type-tree functions
/// (`EnzymeNewTypeTreeCT`, `EnzymeTypeTreeInsertEq`).
///
/// The enum is `#[repr(u32)]` with explicit discriminants because it crosses
/// the FFI boundary by value.
// NOTE(review): the discriminants presumably mirror an enum on the Enzyme side
// and must be kept in sync with it — confirm against the Enzyme C API headers.
// Values 7 and 8 have no variant here.
#[repr(u32)]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub(crate) enum CConcreteType {
    DT_Anything = 0,
    DT_Integer = 1,
    DT_Pointer = 2,
    DT_Half = 3,
    DT_Float = 4,
    DT_Double = 5,
    DT_Unknown = 6,
    DT_FP128 = 9,
}
31
/// Owning Rust wrapper around a raw [`CTypeTreeRef`].
///
/// The `Drop` impl in this file passes `inner` to
/// `EnzymeWrapper::free_type_tree`, and `Clone` creates a new tree through
/// `EnzymeWrapper::new_type_tree_tr`, so each `TypeTree` value is treated as
/// the unique owner of its pointer.
pub(crate) struct TypeTree {
    /// Raw handle handed to the Enzyme function pointers.
    pub(crate) inner: CTypeTreeRef,
}
35
// Helper functions exported by the statically linked `llvm-wrapper` library.
//
// Only `LLVMRustHasMetadata` is declared `safe`; every other function in this
// block has to be called from an `unsafe` block.
#[link(name = "llvm-wrapper", kind = "static")]
unsafe extern "C" {
    pub(crate) safe fn LLVMRustHasMetadata(I: &Value, KindID: MetadataKindId) -> bool;
    pub(crate) fn LLVMRustEraseInstUntilInclusive(BB: &BasicBlock, I: &Value);
    // The returned reference uses a caller-chosen lifetime `'a` that is not
    // tied to `BB`; the caller is responsible for not outliving the instruction.
    pub(crate) fn LLVMRustGetLastInstruction<'a>(BB: &BasicBlock) -> Option<&'a Value>;
    pub(crate) fn LLVMRustDIGetInstMetadata(I: &Value) -> Option<&Metadata>;
    pub(crate) fn LLVMRustEraseInstFromParent(V: &Value);
    // As with `LLVMRustGetLastInstruction`, `'a` is independent of `B`.
    pub(crate) fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value;
    pub(crate) fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
    pub(crate) fn LLVMRustHasAttributeAtIndex(V: &Value, i: c_uint, Kind: AttributeKind) -> bool;
    pub(crate) fn LLVMRustGetArrayNumElements(Ty: &Type) -> u64;
    // The attribute name is passed as pointer + length, so it does not need to
    // be NUL-terminated as far as this signature is concerned.
    pub(crate) fn LLVMRustHasFnAttribute(
        F: &Value,
        Name: *const c_char,
        NameLen: libc::size_t,
    ) -> bool;
    pub(crate) fn LLVMRustRemoveFnAttribute(F: &Value, Name: *const c_char, NameLen: libc::size_t);
    // `None` is how a null return value is represented for these two.
    pub(crate) fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>;
    pub(crate) fn LLVMGetNextFunction(Fn: &Value) -> Option<&Value>;
    pub(crate) fn LLVMRustRemoveEnumAttributeAtIndex(
        Fn: &Value,
        index: c_uint,
        kind: AttributeKind,
    );
    pub(crate) fn LLVMRustPositionBefore<'a>(B: &'a Builder<'_>, I: &'a Value);
    pub(crate) fn LLVMRustPositionAfter<'a>(B: &'a Builder<'_>, I: &'a Value);
    pub(crate) fn LLVMRustGetFunctionCall(
        F: &Value,
        name: *const c_char,
        NameLen: libc::size_t,
    ) -> Option<&Value>;

}
70
// Additional `LLVM*` functions declared without a `#[link]` attribute.
// NOTE(review): presumably these are plain LLVM C API symbols resolved through
// however LLVM itself is linked into the crate — confirm.
unsafe extern "C" {
    pub(crate) fn LLVMDumpModule(M: &Module);
    pub(crate) fn LLVMDumpValue(V: &Value);
    pub(crate) fn LLVMGetFunctionCallConv(F: &Value) -> c_uint;
    pub(crate) fn LLVMGetReturnType(T: &Type) -> &Type;
    // `params` is an out-pointer to caller-provided storage for `&Value`s.
    // NOTE(review): the required capacity is not visible here — confirm with
    // the LLVM C API documentation before calling.
    pub(crate) fn LLVMGetParams(Fnc: &Value, params: *mut &Value);
    // `Name` carries no length argument, so it has to be a NUL-terminated C string.
    pub(crate) fn LLVMGetNamedFunction(M: &Module, Name: *const c_char) -> Option<&Value>;
}
80
/// Value passed as the `action` argument of `LLVMRustVerifyFunction`.
///
/// `#[repr(C)]` with explicit discriminants because it is passed by value
/// across the FFI boundary.
// NOTE(review): variant names suggest abort / print / return-status behavior
// on verification failure, matching LLVM's verifier failure actions — confirm
// against the LLVM C API before relying on it.
#[repr(C)]
#[derive(Copy, Clone, PartialEq)]
pub(crate) enum LLVMRustVerifierFailureAction {
    LLVMAbortProcessAction = 0,
    LLVMPrintMessageAction = 1,
    LLVMReturnStatusAction = 2,
}
88
89pub(crate) use self::Enzyme_AD::*;
90
91pub(crate) mod Enzyme_AD {
92 use std::ffi::{c_char, c_void};
93 use std::sync::{Mutex, MutexGuard, OnceLock};
94
95 use rustc_middle::bug;
96 use rustc_session::config::{Sysroot, host_tuple};
97 use rustc_session::filesearch;
98
99 use super::{CConcreteType, CTypeTreeRef, Context};
100 use crate::llvm::{EnzymeTypeTree, LLVMRustVersionMajor};
101
    // Signatures of the functions that `EnzymeWrapper::call_dynamic` looks up
    // by name in the dynamically loaded Enzyme library. Each alias is named
    // after the symbol it is used for, with an `Fn` suffix.

    // Setters for option objects: the first argument is the address of the
    // option symbol (one of the `*mut c_void` fields of `EnzymeWrapper`).
    type EnzymeSetCLBoolFn = unsafe extern "C" fn(*mut c_void, u8);
    type EnzymeSetCLStringFn = unsafe extern "C" fn(*mut c_void, *const c_char);

    // Type-tree construction, destruction and mutation.
    type EnzymeNewTypeTreeFn = unsafe extern "C" fn() -> CTypeTreeRef;
    type EnzymeNewTypeTreeCTFn = unsafe extern "C" fn(CConcreteType, &Context) -> CTypeTreeRef;
    type EnzymeNewTypeTreeTRFn = unsafe extern "C" fn(CTypeTreeRef) -> CTypeTreeRef;
    type EnzymeFreeTypeTreeFn = unsafe extern "C" fn(CTypeTreeRef);
    type EnzymeMergeTypeTreeFn = unsafe extern "C" fn(CTypeTreeRef, CTypeTreeRef) -> bool;
    type EnzymeTypeTreeOnlyEqFn = unsafe extern "C" fn(CTypeTreeRef, i64);
    type EnzymeTypeTreeData0EqFn = unsafe extern "C" fn(CTypeTreeRef);
    type EnzymeTypeTreeShiftIndiciesEqFn =
        unsafe extern "C" fn(CTypeTreeRef, *const c_char, i64, i64, u64);
    // Indices are passed as pointer + element count.
    type EnzymeTypeTreeInsertEqFn =
        unsafe extern "C" fn(CTypeTreeRef, *const i64, usize, CConcreteType, &Context);
    // The string returned by `ToString` is released again with `ToStringFree`
    // (see `Display for TypeTree`).
    type EnzymeTypeTreeToStringFn = unsafe extern "C" fn(CTypeTreeRef) -> *const c_char;
    type EnzymeTypeTreeToStringFreeFn = unsafe extern "C" fn(*const c_char);
118
    /// Table of symbols resolved from the dynamically loaded Enzyme library.
    ///
    /// Field names deliberately equal the symbol names: the loader macros in
    /// this module `stringify!` the identifier to obtain the name to look up,
    /// hence `#[allow(non_snake_case)]`.
    #[allow(non_snake_case)]
    pub(crate) struct EnzymeWrapper {
        // Type-tree functions, stored as callable function pointers.
        EnzymeNewTypeTree: EnzymeNewTypeTreeFn,
        EnzymeNewTypeTreeCT: EnzymeNewTypeTreeCTFn,
        EnzymeNewTypeTreeTR: EnzymeNewTypeTreeTRFn,
        EnzymeFreeTypeTree: EnzymeFreeTypeTreeFn,
        EnzymeMergeTypeTree: EnzymeMergeTypeTreeFn,
        EnzymeTypeTreeOnlyEq: EnzymeTypeTreeOnlyEqFn,
        EnzymeTypeTreeData0Eq: EnzymeTypeTreeData0EqFn,
        EnzymeTypeTreeShiftIndiciesEq: EnzymeTypeTreeShiftIndiciesEqFn,
        EnzymeTypeTreeInsertEq: EnzymeTypeTreeInsertEqFn,
        EnzymeTypeTreeToString: EnzymeTypeTreeToStringFn,
        EnzymeTypeTreeToStringFree: EnzymeTypeTreeToStringFreeFn,

        // Addresses of option symbols. They are never dereferenced as Rust
        // values except `EnzymeMaxTypeDepth` (read as `u32` in
        // `get_max_type_depth`); the others are only handed to
        // `EnzymeSetCLBool` / `EnzymeSetCLString`.
        EnzymePrintPerf: *mut c_void,
        EnzymePrintActivity: *mut c_void,
        EnzymePrintType: *mut c_void,
        EnzymeFunctionToAnalyze: *mut c_void,
        EnzymePrint: *mut c_void,
        EnzymeStrictAliasing: *mut c_void,
        EnzymeInline: *mut c_void,
        EnzymeMaxTypeDepth: *mut c_void,
        RustTypeRules: *mut c_void,
        looseTypeAnalysis: *mut c_void,

        // Setter functions used together with the option addresses above.
        EnzymeSetCLBool: EnzymeSetCLBoolFn,
        EnzymeSetCLString: EnzymeSetCLStringFn,
        /// Raw address of the `registerEnzymeAndPassPipeline` symbol; exposed
        /// so code outside this module can use it.
        pub registerEnzymeAndPassPipeline: *const c_void,
        // Handle of the library all of the above pointers were resolved from.
        // It is stored next to them so it lives exactly as long as they do.
        lib: libloading::Library,
    }
149
    // The raw-pointer fields make `EnzymeWrapper` neither `Send` nor `Sync`
    // automatically; these impls are what allow it to be stored in the
    // `static ENZYME_INSTANCE`, where it is only reachable through a `Mutex`.
    // NOTE(review): soundness additionally relies on the loaded library
    // tolerating calls from whichever thread holds the lock — confirm.
    unsafe impl Sync for EnzymeWrapper {}
    unsafe impl Send for EnzymeWrapper {}
152
153 fn load_ptr_by_symbol_mut_void(
154 lib: &libloading::Library,
155 bytes: &[u8],
156 ) -> Result<*mut c_void, libloading::Error> {
157 unsafe {
158 let s: libloading::Symbol<'_, *mut c_void> = lib.get(bytes)?;
159 let s = s.try_as_raw_ptr().unwrap();
161 Ok(s)
162 }
163 }
164
165 macro_rules! load_ptrs_by_symbols_mut_void {
171 ($lib:expr, $($name:ident),* $(,)?) => {
172 $(
173 #[allow(non_snake_case)]
174 let $name = load_ptr_by_symbol_mut_void(&$lib, stringify!($name).as_bytes())?;
175 )*
176 };
177 }
178
179 macro_rules! load_ptrs_by_symbols_fn {
185 ($lib:expr, $($name:ident : $ty:ty),* $(,)?) => {
186 $(
187 #[allow(non_snake_case)]
188 let $name: $ty = *unsafe { $lib.get::<$ty>(stringify!($name).as_bytes())? };
189 )*
190 };
191 }
192
    /// Process-wide wrapper instance. It is populated by
    /// `EnzymeWrapper::get_or_init` and afterwards accessed through
    /// `EnzymeWrapper::get_instance`; the `Mutex` is not reentrant, so a thread
    /// must not request the instance again while it still holds a guard.
    static ENZYME_INSTANCE: OnceLock<Mutex<EnzymeWrapper>> = OnceLock::new();
194
    /// Failure while locating or loading the Enzyme library.
    #[derive(Debug)]
    pub(crate) enum EnzymeLibraryError {
        /// No candidate library file exists in any sysroot path
        /// (produced by `EnzymeWrapper::get_enzyme_path`).
        NotFound { err: String },
        /// The library path is not valid UTF-8, or `libloading` reported an
        /// error while opening the library or resolving a symbol.
        LoadFailed { err: String },
    }
200
201 impl From<libloading::Error> for EnzymeLibraryError {
202 fn from(err: libloading::Error) -> Self {
203 Self::LoadFailed { err: format!("{err:?}") }
204 }
205 }
206
207 impl EnzymeWrapper {
208 pub(crate) fn get_or_init(
211 sysroot: &rustc_session::config::Sysroot,
212 ) -> Result<MutexGuard<'static, Self>, EnzymeLibraryError> {
213 let mtx: &'static Mutex<EnzymeWrapper> = ENZYME_INSTANCE.get_or_try_init(|| {
214 let w = Self::call_dynamic(sysroot)?;
215 Ok::<_, EnzymeLibraryError>(Mutex::new(w))
216 })?;
217
218 Ok(mtx.lock().unwrap())
219 }
220
221 pub(crate) fn get_instance() -> MutexGuard<'static, Self> {
223 ENZYME_INSTANCE
224 .get()
225 .expect("EnzymeWrapper not initialized. Call get_or_init with sysroot first.")
226 .lock()
227 .unwrap()
228 }
229
230 pub(crate) fn new_type_tree(&self) -> CTypeTreeRef {
231 unsafe { (self.EnzymeNewTypeTree)() }
232 }
233
234 pub(crate) fn new_type_tree_ct(
235 &self,
236 t: CConcreteType,
237 ctx: &Context,
238 ) -> *mut EnzymeTypeTree {
239 unsafe { (self.EnzymeNewTypeTreeCT)(t, ctx) }
240 }
241
242 pub(crate) fn new_type_tree_tr(&self, tree: CTypeTreeRef) -> CTypeTreeRef {
243 unsafe { (self.EnzymeNewTypeTreeTR)(tree) }
244 }
245
246 pub(crate) fn free_type_tree(&self, tree: CTypeTreeRef) {
247 unsafe { (self.EnzymeFreeTypeTree)(tree) }
248 }
249
250 pub(crate) fn merge_type_tree(&self, tree1: CTypeTreeRef, tree2: CTypeTreeRef) -> bool {
251 unsafe { (self.EnzymeMergeTypeTree)(tree1, tree2) }
252 }
253
254 pub(crate) fn tree_only_eq(&self, tree: CTypeTreeRef, num: i64) {
255 unsafe { (self.EnzymeTypeTreeOnlyEq)(tree, num) }
256 }
257
258 pub(crate) fn tree_data0_eq(&self, tree: CTypeTreeRef) {
259 unsafe { (self.EnzymeTypeTreeData0Eq)(tree) }
260 }
261
262 pub(crate) fn shift_indicies_eq(
263 &self,
264 tree: CTypeTreeRef,
265 data_layout: *const c_char,
266 offset: i64,
267 max_size: i64,
268 add_offset: u64,
269 ) {
270 unsafe {
271 (self.EnzymeTypeTreeShiftIndiciesEq)(
272 tree,
273 data_layout,
274 offset,
275 max_size,
276 add_offset,
277 )
278 }
279 }
280
281 pub(crate) fn tree_insert_eq(
282 &self,
283 tree: CTypeTreeRef,
284 indices: *const i64,
285 len: usize,
286 ct: CConcreteType,
287 ctx: &Context,
288 ) {
289 unsafe { (self.EnzymeTypeTreeInsertEq)(tree, indices, len, ct, ctx) }
290 }
291
292 pub(crate) fn tree_to_string(&self, tree: *mut EnzymeTypeTree) -> *const c_char {
293 unsafe { (self.EnzymeTypeTreeToString)(tree) }
294 }
295
296 pub(crate) fn tree_to_string_free(&self, ch: *const c_char) {
297 unsafe { (self.EnzymeTypeTreeToStringFree)(ch) }
298 }
299
300 pub(crate) fn get_max_type_depth(&self) -> usize {
301 unsafe { std::ptr::read::<u32>(self.EnzymeMaxTypeDepth as *const u32) as usize }
302 }
303
304 pub(crate) fn set_print_perf(&mut self, print: bool) {
305 unsafe {
306 (self.EnzymeSetCLBool)(self.EnzymePrintPerf, print as u8);
307 }
308 }
309
310 pub(crate) fn set_print_activity(&mut self, print: bool) {
311 unsafe {
312 (self.EnzymeSetCLBool)(self.EnzymePrintActivity, print as u8);
313 }
314 }
315
316 pub(crate) fn set_print_type(&mut self, print: bool) {
317 unsafe {
318 (self.EnzymeSetCLBool)(self.EnzymePrintType, print as u8);
319 }
320 }
321
322 pub(crate) fn set_print_type_fun(&mut self, fun_name: &str) {
323 let c_fun_name = std::ffi::CString::new(fun_name)
324 .unwrap_or_else(|err| bug!("failed to set_print_type_fun: {err}"));
325 unsafe {
326 (self.EnzymeSetCLString)(
327 self.EnzymeFunctionToAnalyze,
328 c_fun_name.as_ptr() as *const c_char,
329 );
330 }
331 }
332
333 pub(crate) fn set_print(&mut self, print: bool) {
334 unsafe {
335 (self.EnzymeSetCLBool)(self.EnzymePrint, print as u8);
336 }
337 }
338
339 pub(crate) fn set_strict_aliasing(&mut self, strict: bool) {
340 unsafe {
341 (self.EnzymeSetCLBool)(self.EnzymeStrictAliasing, strict as u8);
342 }
343 }
344
345 pub(crate) fn set_loose_types(&mut self, loose: bool) {
346 unsafe {
347 (self.EnzymeSetCLBool)(self.looseTypeAnalysis, loose as u8);
348 }
349 }
350
351 pub(crate) fn set_inline(&mut self, val: bool) {
352 unsafe {
353 (self.EnzymeSetCLBool)(self.EnzymeInline, val as u8);
354 }
355 }
356
357 pub(crate) fn set_rust_rules(&mut self, val: bool) {
358 unsafe {
359 (self.EnzymeSetCLBool)(self.RustTypeRules, val as u8);
360 }
361 }
362
363 #[allow(non_snake_case)]
364 fn call_dynamic(
365 sysroot: &rustc_session::config::Sysroot,
366 ) -> Result<Self, EnzymeLibraryError> {
367 let enzyme_path = Self::get_enzyme_path(sysroot)?;
368 let lib = unsafe { libloading::Library::new(enzyme_path)? };
369
370 load_ptrs_by_symbols_fn!(
371 lib,
372 EnzymeNewTypeTree: EnzymeNewTypeTreeFn,
373 EnzymeNewTypeTreeCT: EnzymeNewTypeTreeCTFn,
374 EnzymeNewTypeTreeTR: EnzymeNewTypeTreeTRFn,
375 EnzymeFreeTypeTree: EnzymeFreeTypeTreeFn,
376 EnzymeMergeTypeTree: EnzymeMergeTypeTreeFn,
377 EnzymeTypeTreeOnlyEq: EnzymeTypeTreeOnlyEqFn,
378 EnzymeTypeTreeData0Eq: EnzymeTypeTreeData0EqFn,
379 EnzymeTypeTreeShiftIndiciesEq: EnzymeTypeTreeShiftIndiciesEqFn,
380 EnzymeTypeTreeInsertEq: EnzymeTypeTreeInsertEqFn,
381 EnzymeTypeTreeToString: EnzymeTypeTreeToStringFn,
382 EnzymeTypeTreeToStringFree: EnzymeTypeTreeToStringFreeFn,
383 EnzymeSetCLBool: EnzymeSetCLBoolFn,
384 EnzymeSetCLString: EnzymeSetCLStringFn,
385 );
386
387 load_ptrs_by_symbols_mut_void!(
388 lib,
389 registerEnzymeAndPassPipeline,
390 EnzymePrintPerf,
391 EnzymePrintActivity,
392 EnzymePrintType,
393 EnzymeFunctionToAnalyze,
394 EnzymePrint,
395 EnzymeStrictAliasing,
396 EnzymeInline,
397 EnzymeMaxTypeDepth,
398 RustTypeRules,
399 looseTypeAnalysis,
400 );
401
402 Ok(Self {
403 EnzymeNewTypeTree,
404 EnzymeNewTypeTreeCT,
405 EnzymeNewTypeTreeTR,
406 EnzymeFreeTypeTree,
407 EnzymeMergeTypeTree,
408 EnzymeTypeTreeOnlyEq,
409 EnzymeTypeTreeData0Eq,
410 EnzymeTypeTreeShiftIndiciesEq,
411 EnzymeTypeTreeInsertEq,
412 EnzymeTypeTreeToString,
413 EnzymeTypeTreeToStringFree,
414 EnzymePrintPerf,
415 EnzymePrintActivity,
416 EnzymePrintType,
417 EnzymeFunctionToAnalyze,
418 EnzymePrint,
419 EnzymeStrictAliasing,
420 EnzymeInline,
421 EnzymeMaxTypeDepth,
422 RustTypeRules,
423 looseTypeAnalysis,
424 EnzymeSetCLBool,
425 EnzymeSetCLString,
426 registerEnzymeAndPassPipeline,
427 lib,
428 })
429 }
430
431 fn get_enzyme_path(sysroot: &Sysroot) -> Result<String, EnzymeLibraryError> {
432 let llvm_version_major = unsafe { LLVMRustVersionMajor() };
433
434 let path_buf = sysroot
435 .all_paths()
436 .map(|sysroot_path| {
437 filesearch::make_target_lib_path(sysroot_path, host_tuple())
438 .join("lib")
439 .with_file_name(format!("libEnzyme-{llvm_version_major}"))
440 .with_extension(std::env::consts::DLL_EXTENSION)
441 })
442 .find(|f| f.exists())
443 .ok_or_else(|| {
444 let candidates = sysroot
445 .all_paths()
446 .map(|p| p.join("lib").display().to_string())
447 .collect::<Vec<String>>()
448 .join("\n* ");
449 EnzymeLibraryError::NotFound {
450 err: format!(
451 "failed to find a `libEnzyme-{llvm_version_major}` folder \
452 in the sysroot candidates:\n* {candidates}"
453 ),
454 }
455 })?;
456
457 Ok(path_buf
458 .to_str()
459 .ok_or_else(|| EnzymeLibraryError::LoadFailed {
460 err: format!("invalid UTF-8 in path: {}", path_buf.display()),
461 })?
462 .to_string())
463 }
464 }
465}
466
467impl TypeTree {
468 pub(crate) fn new() -> TypeTree {
469 let wrapper = EnzymeWrapper::get_instance();
470 let inner = wrapper.new_type_tree();
471 TypeTree { inner }
472 }
473
474 pub(crate) fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree {
475 let wrapper = EnzymeWrapper::get_instance();
476 let inner = wrapper.new_type_tree_ct(t, ctx);
477 TypeTree { inner }
478 }
479
480 pub(crate) fn merge(self, other: Self) -> Self {
481 let wrapper = EnzymeWrapper::get_instance();
482 wrapper.merge_type_tree(self.inner, other.inner);
483 drop(other);
484 self
485 }
486
487 #[must_use]
488 pub(crate) fn shift(
489 self,
490 layout: &str,
491 offset: isize,
492 max_size: isize,
493 add_offset: usize,
494 ) -> Self {
495 let layout = std::ffi::CString::new(layout).unwrap();
496 let wrapper = EnzymeWrapper::get_instance();
497 wrapper.shift_indicies_eq(
498 self.inner,
499 layout.as_ptr(),
500 offset as i64,
501 max_size as i64,
502 add_offset as u64,
503 );
504
505 self
506 }
507
508 pub(crate) fn insert(&mut self, indices: &[i64], ct: CConcreteType, ctx: &Context) {
509 let wrapper = EnzymeWrapper::get_instance();
510 wrapper.tree_insert_eq(self.inner, indices.as_ptr(), indices.len(), ct, ctx);
511 }
512}
513
514impl Clone for TypeTree {
515 fn clone(&self) -> Self {
516 let wrapper = EnzymeWrapper::get_instance();
517 let inner = wrapper.new_type_tree_tr(self.inner);
518 TypeTree { inner }
519 }
520}
521
522impl std::fmt::Display for TypeTree {
523 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
524 let wrapper = EnzymeWrapper::get_instance();
525 let ptr = wrapper.tree_to_string(self.inner);
526 let cstr = unsafe { std::ffi::CStr::from_ptr(ptr) };
527 match cstr.to_str() {
528 Ok(x) => write!(f, "{}", x)?,
529 Err(err) => write!(f, "could not parse: {}", err)?,
530 }
531
532 wrapper.tree_to_string_free(ptr);
534
535 Ok(())
536 }
537}
538
539impl std::fmt::Debug for TypeTree {
540 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
541 <Self as std::fmt::Display>::fmt(self, f)
542 }
543}
544
545impl Drop for TypeTree {
546 fn drop(&mut self) {
547 let wrapper = EnzymeWrapper::get_instance();
548 wrapper.free_type_tree(self.inner)
549 }
550}