1#![expect(dead_code)]
2
3use libc::{c_char, c_uint};
4
5use super::MetadataKindId;
6use super::ffi::{AttributeKind, BasicBlock, Context, Metadata, Module, Type, Value};
7use crate::llvm::{Bool, Builder};
8
/// Raw pointer to an Enzyme type tree, as exchanged with the Enzyme C API.
///
/// This is an unmanaged pointer: nothing at the type level tracks ownership or validity. The
/// owning wrapper in this file is [`TypeTree`].
pub(crate) type CTypeTreeRef = *mut EnzymeTypeTree;
11
/// Opaque stand-in for the type-tree object on the other side of the FFI boundary.
///
/// The zero-sized private field means Rust code cannot inspect the contents; values of this type
/// are only ever handled behind a [`CTypeTreeRef`] pointer.
#[repr(C)]
#[derive(Debug, Copy, Clone)]
pub(crate) struct EnzymeTypeTree {
    _unused: [u8; 0],
}
17
/// Concrete type tag passed by value to the Enzyme type-tree functions.
///
/// `#[repr(u32)]` with explicit discriminants fixes the ABI representation.
/// NOTE(review): the values presumably mirror the `CConcreteType` enum of Enzyme's C API —
/// confirm against the Enzyme headers before changing or adding variants.
#[repr(u32)]
#[derive(Debug, Copy, Clone, Hash, PartialEq, Eq)]
#[allow(non_camel_case_types)]
pub(crate) enum CConcreteType {
    DT_Anything = 0,
    DT_Integer = 1,
    DT_Pointer = 2,
    DT_Half = 3,
    DT_Float = 4,
    DT_Double = 5,
    DT_Unknown = 6,
    // Discriminants 7 and 8 are intentionally not declared here; the numbering jumps to 9.
    DT_FP128 = 9,
}
31
/// Owning handle around a raw Enzyme type tree.
///
/// The `Drop` implementation in this file passes `inner` to `EnzymeWrapper::free_type_tree`, and
/// `Clone` requests a new tree from Enzyme instead of copying the pointer.
pub(crate) struct TypeTree {
    /// Raw tree pointer handed to the Enzyme functions; released when the `TypeTree` is dropped.
    pub(crate) inner: CTypeTreeRef,
}
35
// Helper functions resolved from the statically linked `llvm-wrapper` library (see the `#[link]`
// attribute). The Rust declarations are not checked against the foreign definitions, so the
// signatures below must be kept in sync with the other side by hand.
#[link(name = "llvm-wrapper", kind = "static")]
unsafe extern "C" {
    // Declared `safe`, so unlike the other items in this block it can be called without `unsafe`.
    pub(crate) safe fn LLVMRustHasMetadata(I: &Value, KindID: MetadataKindId) -> bool;
    pub(crate) fn LLVMRustEraseInstUntilInclusive(BB: &BasicBlock, I: &Value);
    // The returned reference uses a caller-chosen lifetime `'a` that is not tied to `BB`.
    pub(crate) fn LLVMRustGetLastInstruction<'a>(BB: &BasicBlock) -> Option<&'a Value>;
    pub(crate) fn LLVMRustDIGetInstMetadata(I: &Value) -> Option<&Metadata>;
    pub(crate) fn LLVMRustEraseInstFromParent(V: &Value);
    // As above: `'a` is chosen by the caller and is independent of the borrow of `B`.
    pub(crate) fn LLVMRustGetTerminator<'a>(B: &BasicBlock) -> &'a Value;
    pub(crate) fn LLVMRustVerifyFunction(V: &Value, action: LLVMRustVerifierFailureAction) -> Bool;
    pub(crate) fn LLVMRustHasAttributeAtIndex(V: &Value, i: c_uint, Kind: AttributeKind) -> bool;
    pub(crate) fn LLVMRustGetArrayNumElements(Ty: &Type) -> u64;
    // The attribute name is passed as a pointer plus an explicit length (`Name`, `NameLen`).
    pub(crate) fn LLVMRustHasFnAttribute(
        F: &Value,
        Name: *const c_char,
        NameLen: libc::size_t,
    ) -> bool;
    pub(crate) fn LLVMRustRemoveFnAttribute(F: &Value, Name: *const c_char, NameLen: libc::size_t);
    // `None` is returned for a null result; presumably that marks the end of the function list
    // for this pair of iteration functions — confirm against the LLVM-C documentation.
    pub(crate) fn LLVMGetFirstFunction(M: &Module) -> Option<&Value>;
    pub(crate) fn LLVMGetNextFunction(Fn: &Value) -> Option<&Value>;
    pub(crate) fn LLVMRustRemoveEnumAttributeAtIndex(
        Fn: &Value,
        index: c_uint,
        kind: AttributeKind,
    );
    pub(crate) fn LLVMRustPositionBefore<'a>(B: &'a Builder<'_>, I: &'a Value);
    pub(crate) fn LLVMRustPositionAfter<'a>(B: &'a Builder<'_>, I: &'a Value);
    // The name is passed as a pointer plus an explicit length (`name`, `NameLen`).
    pub(crate) fn LLVMRustGetFunctionCall(
        F: &Value,
        name: *const c_char,
        NameLen: libc::size_t,
    ) -> Option<&Value>;

}
70
// Foreign functions declared without a `#[link]` attribute, so they are resolved from whatever
// is already being linked. NOTE(review): the unprefixed `LLVM*` names suggest these are plain
// LLVM-C API entry points — confirm against the LLVM-C headers.
unsafe extern "C" {
    pub(crate) fn LLVMDumpModule(M: &Module);
    pub(crate) fn LLVMDumpValue(V: &Value);
    pub(crate) fn LLVMGetFunctionCallConv(F: &Value) -> c_uint;
    pub(crate) fn LLVMGetReturnType(T: &Type) -> &Type;
    // `params` is an output buffer written by the callee. NOTE(review): presumably it must have
    // room for every parameter of `Fnc`; no length is passed, so the caller has to size it.
    pub(crate) fn LLVMGetParams(Fnc: &Value, params: *mut &Value);
    // `Name` carries no length argument, so it is presumably expected to be NUL-terminated.
    pub(crate) fn LLVMGetNamedFunction(M: &Module, Name: *const c_char) -> Option<&Value>;
}
80
/// Action selector passed to `LLVMRustVerifyFunction`.
///
/// `#[repr(C)]` with explicit discriminants keeps the values stable across the FFI boundary.
/// NOTE(review): variant names and values presumably mirror LLVM-C's
/// `LLVMVerifierFailureAction` — confirm before editing.
#[repr(C)]
#[derive(Copy, Clone, PartialEq)]
pub(crate) enum LLVMRustVerifierFailureAction {
    LLVMAbortProcessAction = 0,
    LLVMPrintMessageAction = 1,
    LLVMReturnStatusAction = 2,
}
88
89pub(crate) use self::Enzyme_AD::*;
90
91pub(crate) mod Enzyme_AD {
92 use std::ffi::{c_char, c_void};
93 use std::sync::{Mutex, MutexGuard, OnceLock};
94
95 use rustc_middle::bug;
96 use rustc_session::config::{Sysroot, host_tuple};
97 use rustc_session::filesearch;
98
99 use super::{CConcreteType, CTypeTreeRef, Context};
100 use crate::llvm::{EnzymeTypeTree, LLVMRustVersionMajor};
101
102 type EnzymeSetCLBoolFn = unsafe extern "C" fn(*mut c_void, u8);
103 type EnzymeSetCLStringFn = unsafe extern "C" fn(*mut c_void, *const c_char);
104
105 type EnzymeNewTypeTreeFn = unsafe extern "C" fn() -> CTypeTreeRef;
106 type EnzymeNewTypeTreeCTFn = unsafe extern "C" fn(CConcreteType, &Context) -> CTypeTreeRef;
107 type EnzymeNewTypeTreeTRFn = unsafe extern "C" fn(CTypeTreeRef) -> CTypeTreeRef;
108 type EnzymeFreeTypeTreeFn = unsafe extern "C" fn(CTypeTreeRef);
109 type EnzymeMergeTypeTreeFn = unsafe extern "C" fn(CTypeTreeRef, CTypeTreeRef) -> bool;
110 type EnzymeTypeTreeOnlyEqFn = unsafe extern "C" fn(CTypeTreeRef, i64);
111 type EnzymeTypeTreeData0EqFn = unsafe extern "C" fn(CTypeTreeRef);
112 type EnzymeTypeTreeShiftIndiciesEqFn =
113 unsafe extern "C" fn(CTypeTreeRef, *const c_char, i64, i64, u64);
114 type EnzymeTypeTreeInsertEqFn =
115 unsafe extern "C" fn(CTypeTreeRef, *const i64, usize, CConcreteType, &Context);
116 type EnzymeTypeTreeToStringFn = unsafe extern "C" fn(CTypeTreeRef) -> *const c_char;
117 type EnzymeTypeTreeToStringFreeFn = unsafe extern "C" fn(*const c_char);
118
119 #[allow(non_snake_case)]
120 pub(crate) struct EnzymeWrapper {
121 EnzymeNewTypeTree: EnzymeNewTypeTreeFn,
122 EnzymeNewTypeTreeCT: EnzymeNewTypeTreeCTFn,
123 EnzymeNewTypeTreeTR: EnzymeNewTypeTreeTRFn,
124 EnzymeFreeTypeTree: EnzymeFreeTypeTreeFn,
125 EnzymeMergeTypeTree: EnzymeMergeTypeTreeFn,
126 EnzymeTypeTreeOnlyEq: EnzymeTypeTreeOnlyEqFn,
127 EnzymeTypeTreeData0Eq: EnzymeTypeTreeData0EqFn,
128 EnzymeTypeTreeShiftIndiciesEq: EnzymeTypeTreeShiftIndiciesEqFn,
129 EnzymeTypeTreeInsertEq: EnzymeTypeTreeInsertEqFn,
130 EnzymeTypeTreeToString: EnzymeTypeTreeToStringFn,
131 EnzymeTypeTreeToStringFree: EnzymeTypeTreeToStringFreeFn,
132
133 EnzymePrintPerf: *mut c_void,
134 EnzymePrintActivity: *mut c_void,
135 EnzymePrintType: *mut c_void,
136 EnzymeFunctionToAnalyze: *mut c_void,
137 EnzymePrint: *mut c_void,
138 EnzymeStrictAliasing: *mut c_void,
139 EnzymeInline: *mut c_void,
140 EnzymeMaxTypeDepth: *mut c_void,
141 RustTypeRules: *mut c_void,
142 looseTypeAnalysis: *mut c_void,
143
144 EnzymeSetCLBool: EnzymeSetCLBoolFn,
145 EnzymeSetCLString: EnzymeSetCLStringFn,
146 pub registerEnzymeAndPassPipeline: *const c_void,
147 lib: libloading::Library,
148 }
149
150 unsafe impl Sync for EnzymeWrapper {}
151 unsafe impl Send for EnzymeWrapper {}
152
153 fn load_ptr_by_symbol_mut_void(
154 lib: &libloading::Library,
155 bytes: &[u8],
156 ) -> Result<*mut c_void, Box<dyn std::error::Error>> {
157 unsafe {
158 let s: libloading::Symbol<'_, *mut c_void> = lib.get(bytes)?;
159 let s = s.try_as_raw_ptr().unwrap();
161 Ok(s)
162 }
163 }
164
165 macro_rules! load_ptrs_by_symbols_mut_void {
171 ($lib:expr, $($name:ident),* $(,)?) => {
172 $(
173 #[allow(non_snake_case)]
174 let $name = load_ptr_by_symbol_mut_void(&$lib, stringify!($name).as_bytes())?;
175 )*
176 };
177 }
178
179 macro_rules! load_ptrs_by_symbols_fn {
185 ($lib:expr, $($name:ident : $ty:ty),* $(,)?) => {
186 $(
187 #[allow(non_snake_case)]
188 let $name: $ty = *unsafe { $lib.get::<$ty>(stringify!($name).as_bytes())? };
189 )*
190 };
191 }
192
193 static ENZYME_INSTANCE: OnceLock<Mutex<EnzymeWrapper>> = OnceLock::new();
194
195 impl EnzymeWrapper {
196 pub(crate) fn get_or_init(
199 sysroot: &rustc_session::config::Sysroot,
200 ) -> Result<MutexGuard<'static, Self>, Box<dyn std::error::Error>> {
201 let mtx: &'static Mutex<EnzymeWrapper> = ENZYME_INSTANCE.get_or_try_init(|| {
202 let w = Self::call_dynamic(sysroot)?;
203 Ok::<_, Box<dyn std::error::Error>>(Mutex::new(w))
204 })?;
205
206 Ok(mtx.lock().unwrap())
207 }
208
209 pub(crate) fn get_instance() -> MutexGuard<'static, Self> {
211 ENZYME_INSTANCE
212 .get()
213 .expect("EnzymeWrapper not initialized. Call get_or_init with sysroot first.")
214 .lock()
215 .unwrap()
216 }
217
218 pub(crate) fn new_type_tree(&self) -> CTypeTreeRef {
219 unsafe { (self.EnzymeNewTypeTree)() }
220 }
221
222 pub(crate) fn new_type_tree_ct(
223 &self,
224 t: CConcreteType,
225 ctx: &Context,
226 ) -> *mut EnzymeTypeTree {
227 unsafe { (self.EnzymeNewTypeTreeCT)(t, ctx) }
228 }
229
230 pub(crate) fn new_type_tree_tr(&self, tree: CTypeTreeRef) -> CTypeTreeRef {
231 unsafe { (self.EnzymeNewTypeTreeTR)(tree) }
232 }
233
234 pub(crate) fn free_type_tree(&self, tree: CTypeTreeRef) {
235 unsafe { (self.EnzymeFreeTypeTree)(tree) }
236 }
237
238 pub(crate) fn merge_type_tree(&self, tree1: CTypeTreeRef, tree2: CTypeTreeRef) -> bool {
239 unsafe { (self.EnzymeMergeTypeTree)(tree1, tree2) }
240 }
241
242 pub(crate) fn tree_only_eq(&self, tree: CTypeTreeRef, num: i64) {
243 unsafe { (self.EnzymeTypeTreeOnlyEq)(tree, num) }
244 }
245
246 pub(crate) fn tree_data0_eq(&self, tree: CTypeTreeRef) {
247 unsafe { (self.EnzymeTypeTreeData0Eq)(tree) }
248 }
249
250 pub(crate) fn shift_indicies_eq(
251 &self,
252 tree: CTypeTreeRef,
253 data_layout: *const c_char,
254 offset: i64,
255 max_size: i64,
256 add_offset: u64,
257 ) {
258 unsafe {
259 (self.EnzymeTypeTreeShiftIndiciesEq)(
260 tree,
261 data_layout,
262 offset,
263 max_size,
264 add_offset,
265 )
266 }
267 }
268
269 pub(crate) fn tree_insert_eq(
270 &self,
271 tree: CTypeTreeRef,
272 indices: *const i64,
273 len: usize,
274 ct: CConcreteType,
275 ctx: &Context,
276 ) {
277 unsafe { (self.EnzymeTypeTreeInsertEq)(tree, indices, len, ct, ctx) }
278 }
279
280 pub(crate) fn tree_to_string(&self, tree: *mut EnzymeTypeTree) -> *const c_char {
281 unsafe { (self.EnzymeTypeTreeToString)(tree) }
282 }
283
284 pub(crate) fn tree_to_string_free(&self, ch: *const c_char) {
285 unsafe { (self.EnzymeTypeTreeToStringFree)(ch) }
286 }
287
288 pub(crate) fn get_max_type_depth(&self) -> usize {
289 unsafe { std::ptr::read::<u32>(self.EnzymeMaxTypeDepth as *const u32) as usize }
290 }
291
292 pub(crate) fn set_print_perf(&mut self, print: bool) {
293 unsafe {
294 (self.EnzymeSetCLBool)(self.EnzymePrintPerf, print as u8);
295 }
296 }
297
298 pub(crate) fn set_print_activity(&mut self, print: bool) {
299 unsafe {
300 (self.EnzymeSetCLBool)(self.EnzymePrintActivity, print as u8);
301 }
302 }
303
304 pub(crate) fn set_print_type(&mut self, print: bool) {
305 unsafe {
306 (self.EnzymeSetCLBool)(self.EnzymePrintType, print as u8);
307 }
308 }
309
310 pub(crate) fn set_print_type_fun(&mut self, fun_name: &str) {
311 let c_fun_name = std::ffi::CString::new(fun_name)
312 .unwrap_or_else(|err| bug!("failed to set_print_type_fun: {err}"));
313 unsafe {
314 (self.EnzymeSetCLString)(
315 self.EnzymeFunctionToAnalyze,
316 c_fun_name.as_ptr() as *const c_char,
317 );
318 }
319 }
320
321 pub(crate) fn set_print(&mut self, print: bool) {
322 unsafe {
323 (self.EnzymeSetCLBool)(self.EnzymePrint, print as u8);
324 }
325 }
326
327 pub(crate) fn set_strict_aliasing(&mut self, strict: bool) {
328 unsafe {
329 (self.EnzymeSetCLBool)(self.EnzymeStrictAliasing, strict as u8);
330 }
331 }
332
333 pub(crate) fn set_loose_types(&mut self, loose: bool) {
334 unsafe {
335 (self.EnzymeSetCLBool)(self.looseTypeAnalysis, loose as u8);
336 }
337 }
338
339 pub(crate) fn set_inline(&mut self, val: bool) {
340 unsafe {
341 (self.EnzymeSetCLBool)(self.EnzymeInline, val as u8);
342 }
343 }
344
345 pub(crate) fn set_rust_rules(&mut self, val: bool) {
346 unsafe {
347 (self.EnzymeSetCLBool)(self.RustTypeRules, val as u8);
348 }
349 }
350
351 #[allow(non_snake_case)]
352 fn call_dynamic(
353 sysroot: &rustc_session::config::Sysroot,
354 ) -> Result<Self, Box<dyn std::error::Error>> {
355 let enzyme_path = Self::get_enzyme_path(sysroot)?;
356 let lib = unsafe { libloading::Library::new(enzyme_path)? };
357
358 load_ptrs_by_symbols_fn!(
359 lib,
360 EnzymeNewTypeTree: EnzymeNewTypeTreeFn,
361 EnzymeNewTypeTreeCT: EnzymeNewTypeTreeCTFn,
362 EnzymeNewTypeTreeTR: EnzymeNewTypeTreeTRFn,
363 EnzymeFreeTypeTree: EnzymeFreeTypeTreeFn,
364 EnzymeMergeTypeTree: EnzymeMergeTypeTreeFn,
365 EnzymeTypeTreeOnlyEq: EnzymeTypeTreeOnlyEqFn,
366 EnzymeTypeTreeData0Eq: EnzymeTypeTreeData0EqFn,
367 EnzymeTypeTreeShiftIndiciesEq: EnzymeTypeTreeShiftIndiciesEqFn,
368 EnzymeTypeTreeInsertEq: EnzymeTypeTreeInsertEqFn,
369 EnzymeTypeTreeToString: EnzymeTypeTreeToStringFn,
370 EnzymeTypeTreeToStringFree: EnzymeTypeTreeToStringFreeFn,
371 EnzymeSetCLBool: EnzymeSetCLBoolFn,
372 EnzymeSetCLString: EnzymeSetCLStringFn,
373 );
374
375 load_ptrs_by_symbols_mut_void!(
376 lib,
377 registerEnzymeAndPassPipeline,
378 EnzymePrintPerf,
379 EnzymePrintActivity,
380 EnzymePrintType,
381 EnzymeFunctionToAnalyze,
382 EnzymePrint,
383 EnzymeStrictAliasing,
384 EnzymeInline,
385 EnzymeMaxTypeDepth,
386 RustTypeRules,
387 looseTypeAnalysis,
388 );
389
390 Ok(Self {
391 EnzymeNewTypeTree,
392 EnzymeNewTypeTreeCT,
393 EnzymeNewTypeTreeTR,
394 EnzymeFreeTypeTree,
395 EnzymeMergeTypeTree,
396 EnzymeTypeTreeOnlyEq,
397 EnzymeTypeTreeData0Eq,
398 EnzymeTypeTreeShiftIndiciesEq,
399 EnzymeTypeTreeInsertEq,
400 EnzymeTypeTreeToString,
401 EnzymeTypeTreeToStringFree,
402 EnzymePrintPerf,
403 EnzymePrintActivity,
404 EnzymePrintType,
405 EnzymeFunctionToAnalyze,
406 EnzymePrint,
407 EnzymeStrictAliasing,
408 EnzymeInline,
409 EnzymeMaxTypeDepth,
410 RustTypeRules,
411 looseTypeAnalysis,
412 EnzymeSetCLBool,
413 EnzymeSetCLString,
414 registerEnzymeAndPassPipeline,
415 lib,
416 })
417 }
418
419 fn get_enzyme_path(sysroot: &Sysroot) -> Result<String, String> {
420 let llvm_version_major = unsafe { LLVMRustVersionMajor() };
421
422 let path_buf = sysroot
423 .all_paths()
424 .map(|sysroot_path| {
425 filesearch::make_target_lib_path(sysroot_path, host_tuple())
426 .join("lib")
427 .with_file_name(format!("libEnzyme-{llvm_version_major}"))
428 .with_extension(std::env::consts::DLL_EXTENSION)
429 })
430 .find(|f| f.exists())
431 .ok_or_else(|| {
432 let candidates = sysroot
433 .all_paths()
434 .map(|p| p.join("lib").display().to_string())
435 .collect::<Vec<String>>()
436 .join("\n* ");
437 format!(
438 "failed to find a `libEnzyme-{llvm_version_major}` folder \
439 in the sysroot candidates:\n* {candidates}"
440 )
441 })?;
442
443 Ok(path_buf
444 .to_str()
445 .ok_or_else(|| format!("invalid UTF-8 in path: {}", path_buf.display()))?
446 .to_string())
447 }
448 }
449}
450
451impl TypeTree {
452 pub(crate) fn new() -> TypeTree {
453 let wrapper = EnzymeWrapper::get_instance();
454 let inner = wrapper.new_type_tree();
455 TypeTree { inner }
456 }
457
458 pub(crate) fn from_type(t: CConcreteType, ctx: &Context) -> TypeTree {
459 let wrapper = EnzymeWrapper::get_instance();
460 let inner = wrapper.new_type_tree_ct(t, ctx);
461 TypeTree { inner }
462 }
463
464 pub(crate) fn merge(self, other: Self) -> Self {
465 let wrapper = EnzymeWrapper::get_instance();
466 wrapper.merge_type_tree(self.inner, other.inner);
467 drop(other);
468 self
469 }
470
471 #[must_use]
472 pub(crate) fn shift(
473 self,
474 layout: &str,
475 offset: isize,
476 max_size: isize,
477 add_offset: usize,
478 ) -> Self {
479 let layout = std::ffi::CString::new(layout).unwrap();
480 let wrapper = EnzymeWrapper::get_instance();
481 wrapper.shift_indicies_eq(
482 self.inner,
483 layout.as_ptr(),
484 offset as i64,
485 max_size as i64,
486 add_offset as u64,
487 );
488
489 self
490 }
491
492 pub(crate) fn insert(&mut self, indices: &[i64], ct: CConcreteType, ctx: &Context) {
493 let wrapper = EnzymeWrapper::get_instance();
494 wrapper.tree_insert_eq(self.inner, indices.as_ptr(), indices.len(), ct, ctx);
495 }
496}
497
498impl Clone for TypeTree {
499 fn clone(&self) -> Self {
500 let wrapper = EnzymeWrapper::get_instance();
501 let inner = wrapper.new_type_tree_tr(self.inner);
502 TypeTree { inner }
503 }
504}
505
506impl std::fmt::Display for TypeTree {
507 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
508 let wrapper = EnzymeWrapper::get_instance();
509 let ptr = wrapper.tree_to_string(self.inner);
510 let cstr = unsafe { std::ffi::CStr::from_ptr(ptr) };
511 match cstr.to_str() {
512 Ok(x) => write!(f, "{}", x)?,
513 Err(err) => write!(f, "could not parse: {}", err)?,
514 }
515
516 wrapper.tree_to_string_free(ptr);
518
519 Ok(())
520 }
521}
522
523impl std::fmt::Debug for TypeTree {
524 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
525 <Self as std::fmt::Display>::fmt(self, f)
526 }
527}
528
529impl Drop for TypeTree {
530 fn drop(&mut self) {
531 let wrapper = EnzymeWrapper::get_instance();
532 wrapper.free_type_tree(self.inner)
533 }
534}