// Source: rustc_codegen_llvm/builder/gpu_offload.rs

use std::ffi::CString;

use llvm::Linkage::*;
use rustc_abi::Align;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_middle::bug;
use rustc_middle::ty::offload_meta::OffloadMetadata;

use crate::builder::Builder;
use crate::common::CodegenCx;
use crate::llvm::AttributePlace::Function;
use crate::llvm::{self, Linkage, Type, Value};
use crate::{SimpleCx, attributes};
16
17// LLVM kernel-independent globals required for offloading
18pub(crate) struct OffloadGlobals<'ll> {
19    pub launcher_fn: &'ll llvm::Value,
20    pub launcher_ty: &'ll llvm::Type,
21
22    pub kernel_args_ty: &'ll llvm::Type,
23
24    pub offload_entry_ty: &'ll llvm::Type,
25
26    pub begin_mapper: &'ll llvm::Value,
27    pub end_mapper: &'ll llvm::Value,
28    pub mapper_fn_ty: &'ll llvm::Type,
29
30    pub ident_t_global: &'ll llvm::Value,
31
32    // FIXME(offload): Drop this, once we fully automated our offload compilation pipeline, since
33    // LLVM will initialize them for us if it sees gpu kernels being registered.
34    pub init_rtls: &'ll llvm::Value,
35}
36
37impl<'ll> OffloadGlobals<'ll> {
38    pub(crate) fn declare(cx: &CodegenCx<'ll, '_>) -> Self {
39        let (launcher_fn, launcher_ty) = generate_launcher(cx);
40        let kernel_args_ty = KernelArgsTy::new_decl(cx);
41        let offload_entry_ty = TgtOffloadEntry::new_decl(cx);
42        let (begin_mapper, _, end_mapper, mapper_fn_ty) = gen_tgt_data_mappers(cx);
43        let ident_t_global = generate_at_one(cx);
44
45        let init_ty = cx.type_func(&[], cx.type_void());
46        let init_rtls = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);
47
48        // We want LLVM's openmp-opt pass to pick up and optimize this module, since it covers both
49        // openmp and offload optimizations.
50        llvm::add_module_flag_u32(cx.llmod(), llvm::ModuleFlagMergeBehavior::Max, "openmp", 51);
51
52        OffloadGlobals {
53            launcher_fn,
54            launcher_ty,
55            kernel_args_ty,
56            offload_entry_ty,
57            begin_mapper,
58            end_mapper,
59            mapper_fn_ty,
60            ident_t_global,
61            init_rtls,
62        }
63    }
64}
65
66// We need to register offload before using it. We also should unregister it once we are done, for
67// good measures. Previously we have done so before and after each individual offload intrinsic
68// call, but that comes at a performance cost. The repeated (un)register calls might also confuse
69// the LLVM ompOpt pass, which tries to move operations to a better location. The easiest solution,
70// which we copy from clang, is to just have those two calls once, in the global ctor/dtor section
71// of the final binary.
72pub(crate) fn register_offload<'ll>(cx: &CodegenCx<'ll, '_>) {
73    // First we check quickly whether we already have done our setup, in which case we return early.
74    // Shouldn't be needed for correctness.
75    let register_lib_name = "__tgt_register_lib";
76    if cx.get_function(register_lib_name).is_some() {
77        return;
78    }
79
80    let reg_lib_decl = cx.type_func(&[cx.type_ptr()], cx.type_void());
81    let register_lib = declare_offload_fn(&cx, register_lib_name, reg_lib_decl);
82    let unregister_lib = declare_offload_fn(&cx, "__tgt_unregister_lib", reg_lib_decl);
83
84    let ptr_null = cx.const_null(cx.type_ptr());
85    let const_struct = cx.const_struct(&[cx.get_const_i32(0), ptr_null, ptr_null, ptr_null], false);
86    let omp_descriptor =
87        add_global(cx, ".omp_offloading.descriptor", const_struct, InternalLinkage);
88    // @.omp_offloading.descriptor = internal constant %__tgt_bin_desc { i32 1, ptr @.omp_offloading.device_images, ptr @__start_llvm_offload_entries, ptr @__stop_llvm_offload_entries }
89    // @.omp_offloading.descriptor = internal constant %__tgt_bin_desc { i32 0, ptr null, ptr null, ptr null }
90
91    let atexit = cx.type_func(&[cx.type_ptr()], cx.type_i32());
92    let atexit_fn = declare_offload_fn(cx, "atexit", atexit);
93
94    let desc_ty = cx.type_func(&[], cx.type_void());
95    let reg_name = ".omp_offloading.descriptor_reg";
96    let unreg_name = ".omp_offloading.descriptor_unreg";
97    let desc_reg_fn = declare_offload_fn(cx, reg_name, desc_ty);
98    let desc_unreg_fn = declare_offload_fn(cx, unreg_name, desc_ty);
99    llvm::set_linkage(desc_reg_fn, InternalLinkage);
100    llvm::set_linkage(desc_unreg_fn, InternalLinkage);
101    llvm::set_section(desc_reg_fn, c".text.startup");
102    llvm::set_section(desc_unreg_fn, c".text.startup");
103
104    // define internal void @.omp_offloading.descriptor_reg() section ".text.startup" {
105    // entry:
106    //   call void @__tgt_register_lib(ptr @.omp_offloading.descriptor)
107    //   %0 = call i32 @atexit(ptr @.omp_offloading.descriptor_unreg)
108    //   ret void
109    // }
110    let bb = Builder::append_block(cx, desc_reg_fn, "entry");
111    let mut a = Builder::build(cx, bb);
112    a.call(reg_lib_decl, None, None, register_lib, &[omp_descriptor], None, None);
113    a.call(atexit, None, None, atexit_fn, &[desc_unreg_fn], None, None);
114    a.ret_void();
115
116    // define internal void @.omp_offloading.descriptor_unreg() section ".text.startup" {
117    // entry:
118    //   call void @__tgt_unregister_lib(ptr @.omp_offloading.descriptor)
119    //   ret void
120    // }
121    let bb = Builder::append_block(cx, desc_unreg_fn, "entry");
122    let mut a = Builder::build(cx, bb);
123    a.call(reg_lib_decl, None, None, unregister_lib, &[omp_descriptor], None, None);
124    a.ret_void();
125
126    // @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 101, ptr @.omp_offloading.descriptor_reg, ptr null }]
127    let args = <[_]>::into_vec(::alloc::boxed::box_new([cx.get_const_i32(101), desc_reg_fn,
                ptr_null]))vec![cx.get_const_i32(101), desc_reg_fn, ptr_null];
128    let const_struct = cx.const_struct(&args, false);
129    let arr = cx.const_array(cx.val_ty(const_struct), &[const_struct]);
130    add_global(cx, "llvm.global_ctors", arr, AppendingLinkage);
131}
132
133pub(crate) struct OffloadKernelDims<'ll> {
134    num_workgroups: &'ll Value,
135    threads_per_block: &'ll Value,
136    workgroup_dims: &'ll Value,
137    thread_dims: &'ll Value,
138}
139
140impl<'ll> OffloadKernelDims<'ll> {
141    pub(crate) fn from_operands<'tcx>(
142        builder: &mut Builder<'_, 'll, 'tcx>,
143        workgroup_op: &OperandRef<'tcx, &'ll llvm::Value>,
144        thread_op: &OperandRef<'tcx, &'ll llvm::Value>,
145    ) -> Self {
146        let cx = builder.cx;
147        let arr_ty = cx.type_array(cx.type_i32(), 3);
148        let four = Align::from_bytes(4).unwrap();
149
150        let OperandValue::Ref(place) = workgroup_op.val else {
151            ::rustc_middle::util::bug::bug_fmt(format_args!("expected array operand by reference"));bug!("expected array operand by reference");
152        };
153        let workgroup_val = builder.load(arr_ty, place.llval, four);
154
155        let OperandValue::Ref(place) = thread_op.val else {
156            ::rustc_middle::util::bug::bug_fmt(format_args!("expected array operand by reference"));bug!("expected array operand by reference");
157        };
158        let thread_val = builder.load(arr_ty, place.llval, four);
159
160        fn mul_dim3<'ll, 'tcx>(
161            builder: &mut Builder<'_, 'll, 'tcx>,
162            arr: &'ll Value,
163        ) -> &'ll Value {
164            let x = builder.extract_value(arr, 0);
165            let y = builder.extract_value(arr, 1);
166            let z = builder.extract_value(arr, 2);
167
168            let xy = builder.mul(x, y);
169            builder.mul(xy, z)
170        }
171
172        let num_workgroups = mul_dim3(builder, workgroup_val);
173        let threads_per_block = mul_dim3(builder, thread_val);
174
175        OffloadKernelDims {
176            workgroup_dims: workgroup_val,
177            thread_dims: thread_val,
178            num_workgroups,
179            threads_per_block,
180        }
181    }
182}
183
184// ; Function Attrs: nounwind
185// declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) #2
186fn generate_launcher<'ll>(cx: &CodegenCx<'ll, '_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
187    let tptr = cx.type_ptr();
188    let ti64 = cx.type_i64();
189    let ti32 = cx.type_i32();
190    let args = <[_]>::into_vec(::alloc::boxed::box_new([tptr, ti64, ti32, ti32, tptr, tptr]))vec![tptr, ti64, ti32, ti32, tptr, tptr];
191    let tgt_fn_ty = cx.type_func(&args, ti32);
192    let name = "__tgt_target_kernel";
193    let tgt_decl = declare_offload_fn(&cx, name, tgt_fn_ty);
194    let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
195    attributes::apply_to_llfn(tgt_decl, Function, &[nounwind]);
196    (tgt_decl, tgt_fn_ty)
197}
198
199// What is our @1 here? A magic global, used in our data_{begin/update/end}_mapper:
200// @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
201// @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
202// FIXME(offload): @0 should include the file name (e.g. lib.rs) in which the function to be
203// offloaded was defined.
204pub(crate) fn generate_at_one<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll llvm::Value {
205    let unknown_txt = ";unknown;unknown;0;0;;";
206    let c_entry_name = CString::new(unknown_txt).unwrap();
207    let c_val = c_entry_name.as_bytes_with_nul();
208    let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
209    let at_zero = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
210    llvm::set_alignment(at_zero, Align::ONE);
211
212    // @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
213    let struct_ident_ty = cx.type_named_struct("struct.ident_t");
214    let struct_elems = <[_]>::into_vec(::alloc::boxed::box_new([cx.get_const_i32(0),
                cx.get_const_i32(2), cx.get_const_i32(0),
                cx.get_const_i32(22), at_zero]))vec![
215        cx.get_const_i32(0),
216        cx.get_const_i32(2),
217        cx.get_const_i32(0),
218        cx.get_const_i32(22),
219        at_zero,
220    ];
221    let struct_elems_ty: Vec<_> = struct_elems.iter().map(|&x| cx.val_ty(x)).collect();
222    let initializer = crate::common::named_struct(struct_ident_ty, &struct_elems);
223    cx.set_struct_body(struct_ident_ty, &struct_elems_ty, false);
224    let at_one = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
225    llvm::set_alignment(at_one, Align::EIGHT);
226    at_one
227}
228
/// Namespace for building `struct.__tgt_offload_entry` types and values.
///
/// The Rust struct has no fields; the comments below list the fields of the LLVM-side struct.
pub(crate) struct TgtOffloadEntry {
    //   uint64_t Reserved;
    //   uint16_t Version;
    //   uint16_t Kind;
    //   uint32_t Flags; Flags associated with the entry (see Target Region Entry Flags)
    //   void *Address; Address of global symbol within device image (function or global)
    //   char *SymbolName;
    //   uint64_t Size; Size of the entry info (0 if it is a function)
    //   uint64_t Data;
    //   void *AuxAddr;
}
240
241impl TgtOffloadEntry {
242    pub(crate) fn new_decl<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll llvm::Type {
243        let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
244        let tptr = cx.type_ptr();
245        let ti64 = cx.type_i64();
246        let ti32 = cx.type_i32();
247        let ti16 = cx.type_i16();
248        // For each kernel to run on the gpu, we will later generate one entry of this type.
249        // copied from LLVM
250        let entry_elements = <[_]>::into_vec(::alloc::boxed::box_new([ti64, ti16, ti16, ti32, tptr, tptr,
                ti64, ti64, tptr]))vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
251        cx.set_struct_body(offload_entry_ty, &entry_elements, false);
252        offload_entry_ty
253    }
254
255    fn new<'ll>(
256        cx: &CodegenCx<'ll, '_>,
257        region_id: &'ll Value,
258        llglobal: &'ll Value,
259    ) -> [&'ll Value; 9] {
260        let reserved = cx.get_const_i64(0);
261        let version = cx.get_const_i16(1);
262        let kind = cx.get_const_i16(1);
263        let flags = cx.get_const_i32(0);
264        let size = cx.get_const_i64(0);
265        let data = cx.get_const_i64(0);
266        let aux_addr = cx.const_null(cx.type_ptr());
267        [reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr]
268    }
269}
270
/// Namespace for building `struct.__tgt_kernel_arguments` types and values.
///
/// The Rust struct has no fields; the comments below are taken from the LLVM APITypes.h
/// declaration of the LLVM-side struct.
struct KernelArgsTy {
    //  uint32_t Version = 0; // Version of this struct for ABI compatibility.
    //  uint32_t NumArgs = 0; // Number of arguments in each input pointer.
    //  void **ArgBasePtrs =
    //      nullptr;                 // Base pointer of each argument (e.g. a struct).
    //  void **ArgPtrs = nullptr;    // Pointer to the argument data.
    //  int64_t *ArgSizes = nullptr; // Size of the argument data in bytes.
    //  int64_t *ArgTypes = nullptr; // Type of the data (e.g. to / from).
    //  void **ArgNames = nullptr;   // Name of the data for debugging, possibly null.
    //  void **ArgMappers = nullptr; // User-defined mappers, possibly null.
    //  uint64_t Tripcount =
    // 0; // Tripcount for the teams / distribute loop, 0 otherwise.
    // struct {
    //    uint64_t NoWait : 1; // Was this kernel spawned with a `nowait` clause.
    //    uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA.
    //    uint64_t Unused : 62;
    //  } Flags = {0, 0, 0}; // totals to 64 Bit, 8 Byte
    //  // The number of teams (for x,y,z dimension).
    //  uint32_t NumTeams[3] = {0, 0, 0};
    //  // The number of threads (for x,y,z dimension).
    //  uint32_t ThreadLimit[3] = {0, 0, 0};
    //  uint32_t DynCGroupMem = 0; // Amount of dynamic cgroup memory requested.
}
295
296impl KernelArgsTy {
297    const OFFLOAD_VERSION: u64 = 3;
298    const FLAGS: u64 = 0;
299    const TRIPCOUNT: u64 = 0;
300    fn new_decl<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll Type {
301        let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
302        let tptr = cx.type_ptr();
303        let ti64 = cx.type_i64();
304        let ti32 = cx.type_i32();
305        let tarr = cx.type_array(ti32, 3);
306
307        let kernel_elements =
308            <[_]>::into_vec(::alloc::boxed::box_new([ti32, ti32, tptr, tptr, tptr, tptr,
                tptr, tptr, ti64, ti64, tarr, tarr, ti32]))vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
309
310        cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
311        kernel_arguments_ty
312    }
313
314    fn new<'ll, 'tcx>(
315        cx: &CodegenCx<'ll, 'tcx>,
316        num_args: u64,
317        memtransfer_types: &'ll Value,
318        geps: [&'ll Value; 3],
319        workgroup_dims: &'ll Value,
320        thread_dims: &'ll Value,
321    ) -> [(Align, &'ll Value); 13] {
322        let four = Align::from_bytes(4).expect("4 Byte alignment should work");
323        let eight = Align::EIGHT;
324
325        [
326            (four, cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION)),
327            (four, cx.get_const_i32(num_args)),
328            (eight, geps[0]),
329            (eight, geps[1]),
330            (eight, geps[2]),
331            (eight, memtransfer_types),
332            // The next two are debug infos. FIXME(offload): set them
333            (eight, cx.const_null(cx.type_ptr())), // dbg
334            (eight, cx.const_null(cx.type_ptr())), // dbg
335            (eight, cx.get_const_i64(KernelArgsTy::TRIPCOUNT)),
336            (eight, cx.get_const_i64(KernelArgsTy::FLAGS)),
337            (four, workgroup_dims),
338            (four, thread_dims),
339            (four, cx.get_const_i32(0)),
340        ]
341    }
342}
343
344// Contains LLVM values needed to manage offloading for a single kernel.
345#[derive(#[automatically_derived]
impl<'ll> ::core::marker::Copy for OffloadKernelGlobals<'ll> { }Copy, #[automatically_derived]
impl<'ll> ::core::clone::Clone for OffloadKernelGlobals<'ll> {
    #[inline]
    fn clone(&self) -> OffloadKernelGlobals<'ll> {
        let _: ::core::clone::AssertParamIsClone<&'ll llvm::Value>;
        let _: ::core::clone::AssertParamIsClone<&'ll llvm::Value>;
        let _: ::core::clone::AssertParamIsClone<&'ll llvm::Value>;
        let _: ::core::clone::AssertParamIsClone<&'ll llvm::Value>;
        *self
    }
}Clone)]
346pub(crate) struct OffloadKernelGlobals<'ll> {
347    pub offload_sizes: &'ll llvm::Value,
348    pub memtransfer_types: &'ll llvm::Value,
349    pub region_id: &'ll llvm::Value,
350    pub offload_entry: &'ll llvm::Value,
351}
352
353fn gen_tgt_data_mappers<'ll>(
354    cx: &CodegenCx<'ll, '_>,
355) -> (&'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Type) {
356    let tptr = cx.type_ptr();
357    let ti64 = cx.type_i64();
358    let ti32 = cx.type_i32();
359
360    let args = <[_]>::into_vec(::alloc::boxed::box_new([tptr, ti64, ti32, tptr, tptr, tptr,
                tptr, tptr, tptr]))vec![tptr, ti64, ti32, tptr, tptr, tptr, tptr, tptr, tptr];
361    let mapper_fn_ty = cx.type_func(&args, cx.type_void());
362    let mapper_begin = "__tgt_target_data_begin_mapper";
363    let mapper_update = "__tgt_target_data_update_mapper";
364    let mapper_end = "__tgt_target_data_end_mapper";
365    let begin_mapper_decl = declare_offload_fn(&cx, mapper_begin, mapper_fn_ty);
366    let update_mapper_decl = declare_offload_fn(&cx, mapper_update, mapper_fn_ty);
367    let end_mapper_decl = declare_offload_fn(&cx, mapper_end, mapper_fn_ty);
368
369    let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
370    attributes::apply_to_llfn(begin_mapper_decl, Function, &[nounwind]);
371    attributes::apply_to_llfn(update_mapper_decl, Function, &[nounwind]);
372    attributes::apply_to_llfn(end_mapper_decl, Function, &[nounwind]);
373
374    (begin_mapper_decl, update_mapper_decl, end_mapper_decl, mapper_fn_ty)
375}
376
377fn add_priv_unnamed_arr<'ll>(cx: &SimpleCx<'ll>, name: &str, vals: &[u64]) -> &'ll llvm::Value {
378    let ti64 = cx.type_i64();
379    let mut size_val = Vec::with_capacity(vals.len());
380    for &val in vals {
381        size_val.push(cx.get_const_i64(val));
382    }
383    let initializer = cx.const_array(ti64, &size_val);
384    add_unnamed_global(cx, name, initializer, PrivateLinkage)
385}
386
387pub(crate) fn add_unnamed_global<'ll>(
388    cx: &SimpleCx<'ll>,
389    name: &str,
390    initializer: &'ll llvm::Value,
391    l: Linkage,
392) -> &'ll llvm::Value {
393    let llglobal = add_global(cx, name, initializer, l);
394    llvm::LLVMSetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global);
395    llglobal
396}
397
398pub(crate) fn add_global<'ll>(
399    cx: &SimpleCx<'ll>,
400    name: &str,
401    initializer: &'ll llvm::Value,
402    l: Linkage,
403) -> &'ll llvm::Value {
404    let c_name = CString::new(name).unwrap();
405    let llglobal: &'ll llvm::Value = llvm::add_global(cx.llmod, cx.val_ty(initializer), &c_name);
406    llvm::set_global_constant(llglobal, true);
407    llvm::set_linkage(llglobal, l);
408    llvm::set_initializer(llglobal, initializer);
409    llglobal
410}
411
412// This function returns a memtransfer value which encodes how arguments to this kernel shall be
413// mapped to/from the gpu. It also returns a region_id with the name of this kernel, to be
414// concatenated into the list of region_ids.
415pub(crate) fn gen_define_handling<'ll>(
416    cx: &CodegenCx<'ll, '_>,
417    metadata: &[OffloadMetadata],
418    symbol: String,
419    offload_globals: &OffloadGlobals<'ll>,
420) -> OffloadKernelGlobals<'ll> {
421    if let Some(entry) = cx.offload_kernel_cache.borrow().get(&symbol) {
422        return *entry;
423    }
424
425    let offload_entry_ty = offload_globals.offload_entry_ty;
426
427    // FIXME(Sa4dUs): add `OMP_MAP_TARGET_PARAM = 0x20` only if necessary
428    let (sizes, transfer): (Vec<_>, Vec<_>) =
429        metadata.iter().map(|m| (m.payload_size, m.mode.bits() | 0x20)).unzip();
430
431    let offload_sizes = add_priv_unnamed_arr(&cx, &::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!(".offload_sizes.{0}", symbol))
    })format!(".offload_sizes.{symbol}"), &sizes);
432    // Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
433    // or both to and from the gpu (=3). Other values shouldn't affect us for now.
434    // A non-mutable reference or pointer will be 1, an array that's not read, but fully overwritten
435    // will be 2. For now, everything is 3, until we have our frontend set up.
436    // 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (Add one extra input ptr per function, to be used later).
437    let memtransfer_types =
438        add_priv_unnamed_arr(&cx, &::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!(".offload_maptypes.{0}", symbol))
    })format!(".offload_maptypes.{symbol}"), &transfer);
439
440    // Next: For each function, generate these three entries. A weak constant,
441    // the llvm.rodata entry name, and  the llvm_offload_entries value
442
443    let name = ::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!(".{0}.region_id", symbol))
    })format!(".{symbol}.region_id");
444    let initializer = cx.get_const_i8(0);
445    let region_id = add_global(&cx, &name, initializer, WeakAnyLinkage);
446
447    let c_entry_name = CString::new(symbol.clone()).unwrap();
448    let c_val = c_entry_name.as_bytes_with_nul();
449    let offload_entry_name = ::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!(".offloading.entry_name.{0}",
                symbol))
    })format!(".offloading.entry_name.{symbol}");
450
451    let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
452    let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
453    llvm::set_alignment(llglobal, Align::ONE);
454    llvm::set_section(llglobal, c".llvm.rodata.offloading");
455
456    let name = ::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!(".offloading.entry.{0}", symbol))
    })format!(".offloading.entry.{symbol}");
457
458    // See the __tgt_offload_entry documentation above.
459    let elems = TgtOffloadEntry::new(&cx, region_id, llglobal);
460
461    let initializer = crate::common::named_struct(offload_entry_ty, &elems);
462    let c_name = CString::new(name).unwrap();
463    let offload_entry = llvm::add_global(cx.llmod, offload_entry_ty, &c_name);
464    llvm::set_global_constant(offload_entry, true);
465    llvm::set_linkage(offload_entry, WeakAnyLinkage);
466    llvm::set_initializer(offload_entry, initializer);
467    llvm::set_alignment(offload_entry, Align::EIGHT);
468    let c_section_name = CString::new("llvm_offload_entries").unwrap();
469    llvm::set_section(offload_entry, &c_section_name);
470
471    let result =
472        OffloadKernelGlobals { offload_sizes, memtransfer_types, region_id, offload_entry };
473
474    cx.offload_kernel_cache.borrow_mut().insert(symbol, result);
475
476    result
477}
478
479fn declare_offload_fn<'ll>(
480    cx: &CodegenCx<'ll, '_>,
481    name: &str,
482    ty: &'ll llvm::Type,
483) -> &'ll llvm::Value {
484    crate::declare::declare_simple_fn(
485        cx,
486        name,
487        llvm::CallConv::CCallConv,
488        llvm::UnnamedAddr::No,
489        llvm::Visibility::Default,
490        ty,
491    )
492}
493
494pub(crate) fn scalar_width<'ll>(cx: &'ll SimpleCx<'_>, ty: &'ll Type) -> u64 {
495    match cx.type_kind(ty) {
496        TypeKind::Half
497        | TypeKind::Float
498        | TypeKind::Double
499        | TypeKind::X86_FP80
500        | TypeKind::FP128
501        | TypeKind::PPC_FP128 => cx.float_width(ty) as u64,
502        TypeKind::Integer => cx.int_width(ty),
503        other => ::rustc_middle::util::bug::bug_fmt(format_args!("scalar_width was called on a non scalar type {0:?}",
        other))bug!("scalar_width was called on a non scalar type {other:?}"),
504    }
505}
506
507// For each kernel *call*, we now use some of our previous declared globals to move data to and from
508// the gpu. For now, we only handle the data transfer part of it.
509// If two consecutive kernels use the same memory, we still move it to the host and back to the gpu.
510// Since in our frontend users (by default) don't have to specify data transfer, this is something
511// we should optimize in the future! In some cases we can directly zero-allocate on the device and
512// only move data back, or if something is immutable, we might only copy it to the device.
513//
514// Current steps:
515// 0. Alloca some variables for the following steps
516// 1. set insert point before kernel call.
517// 2. generate all the GEPS and stores, to be used in 3)
518// 3. generate __tgt_target_data_begin calls to move data to the GPU
519//
520// unchanged: keep kernel call. Later move the kernel to the GPU
521//
522// 4. set insert point after kernel call.
523// 5. generate all the GEPS and stores, to be used in 6)
524// 6. generate __tgt_target_data_end calls to move data from the GPU
525pub(crate) fn gen_call_handling<'ll, 'tcx>(
526    builder: &mut Builder<'_, 'll, 'tcx>,
527    offload_data: &OffloadKernelGlobals<'ll>,
528    args: &[&'ll Value],
529    types: &[&Type],
530    metadata: &[OffloadMetadata],
531    offload_globals: &OffloadGlobals<'ll>,
532    offload_dims: &OffloadKernelDims<'ll>,
533) {
534    let cx = builder.cx;
535    let OffloadKernelGlobals { offload_sizes, offload_entry, memtransfer_types, region_id } =
536        offload_data;
537    let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } =
538        offload_dims;
539
540    let tgt_decl = offload_globals.launcher_fn;
541    let tgt_target_kernel_ty = offload_globals.launcher_ty;
542
543    let tgt_kernel_decl = offload_globals.kernel_args_ty;
544    let begin_mapper_decl = offload_globals.begin_mapper;
545    let end_mapper_decl = offload_globals.end_mapper;
546    let fn_ty = offload_globals.mapper_fn_ty;
547
548    let num_args = types.len() as u64;
549    let bb = builder.llbb();
550
551    // FIXME(Sa4dUs): dummy loads are a temp workaround, we should find a proper way to prevent these
552    // variables from being optimized away
553    for val in [offload_sizes, offload_entry] {
554        unsafe {
555            let dummy = llvm::LLVMBuildLoad2(
556                &builder.llbuilder,
557                llvm::LLVMTypeOf(val),
558                val,
559                b"dummy\0".as_ptr() as *const _,
560            );
561            llvm::LLVMSetVolatile(dummy, llvm::TRUE);
562        }
563    }
564
565    // Step 0)
566    unsafe {
567        llvm::LLVMRustPositionBuilderPastAllocas(&builder.llbuilder, builder.llfn());
568    }
569
570    let ty = cx.type_array(cx.type_ptr(), num_args);
571    // Baseptr are just the input pointer to the kernel, stored in a local alloca
572    let a1 = builder.direct_alloca(ty, Align::EIGHT, ".offload_baseptrs");
573    // Ptrs are the result of a gep into the baseptr, at least for our trivial types.
574    let a2 = builder.direct_alloca(ty, Align::EIGHT, ".offload_ptrs");
575    // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
576    let ty2 = cx.type_array(cx.type_i64(), num_args);
577    let a4 = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes");
578
579    //%kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
580    let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args");
581
582    // Step 1)
583    unsafe {
584        llvm::LLVMPositionBuilderAtEnd(&builder.llbuilder, bb);
585    }
586
587    // Now we allocate once per function param, a copy to be passed to one of our maps.
588    let mut vals = ::alloc::vec::Vec::new()vec![];
589    let mut geps = ::alloc::vec::Vec::new()vec![];
590    let i32_0 = cx.get_const_i32(0);
591    for &v in args {
592        let ty = cx.val_ty(v);
593        let ty_kind = cx.type_kind(ty);
594        let (base_val, gep_base) = match ty_kind {
595            TypeKind::Pointer => (v, v),
596            TypeKind::Half | TypeKind::Float | TypeKind::Double | TypeKind::Integer => {
597                // FIXME(Sa4dUs): check for `f128` support, latest NVIDIA cards support it
598                let num_bits = scalar_width(cx, ty);
599
600                let bb = builder.llbb();
601                unsafe {
602                    llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, builder.llfn());
603                }
604                let addr = builder.direct_alloca(cx.type_i64(), Align::EIGHT, "addr");
605                unsafe {
606                    llvm::LLVMPositionBuilderAtEnd(builder.llbuilder, bb);
607                }
608
609                let cast = builder.bitcast(v, cx.type_ix(num_bits));
610                let value = builder.zext(cast, cx.type_i64());
611                builder.store(value, addr, Align::EIGHT);
612                (value, addr)
613            }
614            other => ::rustc_middle::util::bug::bug_fmt(format_args!("offload does not support {0:?}",
        other))bug!("offload does not support {other:?}"),
615        };
616
617        let gep = builder.inbounds_gep(cx.type_f32(), gep_base, &[i32_0]);
618
619        vals.push(base_val);
620        geps.push(gep);
621    }
622
623    let init_ty = cx.type_func(&[], cx.type_void());
624    let init_rtls_decl = offload_globals.init_rtls;
625
626    // call void @__tgt_init_all_rtls()
627    builder.call(init_ty, None, None, init_rtls_decl, &[], None, None);
628
629    for i in 0..num_args {
630        let idx = cx.get_const_i32(i);
631        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]);
632        builder.store(vals[i as usize], gep1, Align::EIGHT);
633        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, idx]);
634        builder.store(geps[i as usize], gep2, Align::EIGHT);
635        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, idx]);
636        // FIXME(offload): write an offload frontend and handle arbitrary types.
637        builder.store(cx.get_const_i64(metadata[i as usize].payload_size), gep3, Align::EIGHT);
638    }
639
640    // For now we have a very simplistic indexing scheme into our
641    // offload_{baseptrs,ptrs,sizes}. We will probably improve this along with our gpu frontend pr.
642    fn get_geps<'ll, 'tcx>(
643        builder: &mut Builder<'_, 'll, 'tcx>,
644        ty: &'ll Type,
645        ty2: &'ll Type,
646        a1: &'ll Value,
647        a2: &'ll Value,
648        a4: &'ll Value,
649    ) -> [&'ll Value; 3] {
650        let cx = builder.cx;
651        let i32_0 = cx.get_const_i32(0);
652
653        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
654        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
655        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
656        [gep1, gep2, gep3]
657    }
658
659    fn generate_mapper_call<'ll, 'tcx>(
660        builder: &mut Builder<'_, 'll, 'tcx>,
661        geps: [&'ll Value; 3],
662        o_type: &'ll Value,
663        fn_to_call: &'ll Value,
664        fn_ty: &'ll Type,
665        num_args: u64,
666        s_ident_t: &'ll Value,
667    ) {
668        let cx = builder.cx;
669        let nullptr = cx.const_null(cx.type_ptr());
670        let i64_max = cx.get_const_i64(u64::MAX);
671        let num_args = cx.get_const_i32(num_args);
672        let args =
673            <[_]>::into_vec(::alloc::boxed::box_new([s_ident_t, i64_max, num_args,
                geps[0], geps[1], geps[2], o_type, nullptr, nullptr]))vec![s_ident_t, i64_max, num_args, geps[0], geps[1], geps[2], o_type, nullptr, nullptr];
674        builder.call(fn_ty, None, None, fn_to_call, &args, None, None);
675    }
676
677    // Step 2)
678    let s_ident_t = offload_globals.ident_t_global;
679    let geps = get_geps(builder, ty, ty2, a1, a2, a4);
680    generate_mapper_call(
681        builder,
682        geps,
683        memtransfer_types,
684        begin_mapper_decl,
685        fn_ty,
686        num_args,
687        s_ident_t,
688    );
689    let values =
690        KernelArgsTy::new(&cx, num_args, memtransfer_types, geps, workgroup_dims, thread_dims);
691
692    // Step 3)
693    // Here we fill the KernelArgsTy, see the documentation above
694    for (i, value) in values.iter().enumerate() {
695        let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]);
696        builder.store(value.1, ptr, value.0);
697    }
698
699    let args = <[_]>::into_vec(::alloc::boxed::box_new([s_ident_t,
                cx.get_const_i64(u64::MAX), num_workgroups, threads_per_block,
                region_id, a5]))vec![
700        s_ident_t,
701        // FIXME(offload) give users a way to select which GPU to use.
702        cx.get_const_i64(u64::MAX), // MAX == -1.
703        num_workgroups,
704        threads_per_block,
705        region_id,
706        a5,
707    ];
708    builder.call(tgt_target_kernel_ty, None, None, tgt_decl, &args, None, None);
709    // %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
710
711    // Step 4)
712    let geps = get_geps(builder, ty, ty2, a1, a2, a4);
713    generate_mapper_call(
714        builder,
715        geps,
716        memtransfer_types,
717        end_mapper_decl,
718        fn_ty,
719        num_args,
720        s_ident_t,
721    );
722}