rustc_codegen_llvm/builder/gpu_offload.rs

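//! Host-side codegen for GPU offloading via the LLVM/OpenMP offloading runtime (the `__tgt_*`
//! entry points): declares the runtime functions and per-kernel globals, and emits the
//! data-mapping and kernel-launch calls around each offloaded kernel call.
//!
//! Rough call sequence (a sketch, not the exact call sites):
//! ```ignore (illustrative)
//! let globals = OffloadGlobals::declare(cx);
//! let kernel = gen_define_handling(cx, metadata, types, symbol, &globals);
//! let dims = OffloadKernelDims::from_operands(builder, &workgroup_op, &thread_op);
//! gen_call_handling(builder, &kernel, args, types, metadata, &globals, &dims);
//! ```
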
use std::ffi::CString;

use llvm::Linkage::*;
use rustc_abi::Align;
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_middle::bug;
use rustc_middle::ty::offload_meta::OffloadMetadata;

use crate::builder::Builder;
use crate::common::CodegenCx;
use crate::llvm::AttributePlace::Function;
use crate::llvm::{self, Linkage, Type, Value};
use crate::{SimpleCx, attributes};

// LLVM kernel-independent globals required for offloading
pub(crate) struct OffloadGlobals<'ll> {
    pub launcher_fn: &'ll llvm::Value,
    pub launcher_ty: &'ll llvm::Type,

    pub bin_desc: &'ll llvm::Type,

    pub kernel_args_ty: &'ll llvm::Type,

    pub offload_entry_ty: &'ll llvm::Type,

    pub begin_mapper: &'ll llvm::Value,
    pub end_mapper: &'ll llvm::Value,
    pub mapper_fn_ty: &'ll llvm::Type,

    pub ident_t_global: &'ll llvm::Value,

    pub register_lib: &'ll llvm::Value,
    pub unregister_lib: &'ll llvm::Value,
    pub init_rtls: &'ll llvm::Value,
}

impl<'ll> OffloadGlobals<'ll> {
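    // Declares all kernel-independent runtime functions and struct types (launcher, mappers,
    // registration functions, ident_t) so per-kernel codegen can reuse them.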
    pub(crate) fn declare(cx: &CodegenCx<'ll, '_>) -> Self {
        let (launcher_fn, launcher_ty) = generate_launcher(cx);
        let kernel_args_ty = KernelArgsTy::new_decl(cx);
        let offload_entry_ty = TgtOffloadEntry::new_decl(cx);
        let (begin_mapper, _, end_mapper, mapper_fn_ty) = gen_tgt_data_mappers(cx);
        let ident_t_global = generate_at_one(cx);

        let tptr = cx.type_ptr();
        let ti32 = cx.type_i32();
        let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr];
        let bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
        cx.set_struct_body(bin_desc, &tgt_bin_desc_ty, false);

        let reg_lib_decl = cx.type_func(&[cx.type_ptr()], cx.type_void());
        let register_lib = declare_offload_fn(&cx, "__tgt_register_lib", reg_lib_decl);
        let unregister_lib = declare_offload_fn(&cx, "__tgt_unregister_lib", reg_lib_decl);
        let init_ty = cx.type_func(&[], cx.type_void());
        let init_rtls = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);

        OffloadGlobals {
            launcher_fn,
            launcher_ty,
            bin_desc,
            kernel_args_ty,
            offload_entry_ty,
            begin_mapper,
            end_mapper,
            mapper_fn_ty,
            ident_t_global,
            register_lib,
            unregister_lib,
            init_rtls,
        }
    }
}

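// Launch dimensions for a single kernel call: the raw 3-element workgroup and thread dimension
// arrays plus their products (total number of workgroups and threads per block), as passed to
// `__tgt_target_kernel`.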
pub(crate) struct OffloadKernelDims<'ll> {
    num_workgroups: &'ll Value,
    threads_per_block: &'ll Value,
    workgroup_dims: &'ll Value,
    thread_dims: &'ll Value,
}

impl<'ll> OffloadKernelDims<'ll> {
    pub(crate) fn from_operands<'tcx>(
        builder: &mut Builder<'_, 'll, 'tcx>,
        workgroup_op: &OperandRef<'tcx, &'ll llvm::Value>,
        thread_op: &OperandRef<'tcx, &'ll llvm::Value>,
    ) -> Self {
        let cx = builder.cx;
        let arr_ty = cx.type_array(cx.type_i32(), 3);
        let four = Align::from_bytes(4).unwrap();

        let OperandValue::Ref(place) = workgroup_op.val else {
            bug!("expected array operand by reference");
        };
        let workgroup_val = builder.load(arr_ty, place.llval, four);

        let OperandValue::Ref(place) = thread_op.val else {
            bug!("expected array operand by reference");
        };
        let thread_val = builder.load(arr_ty, place.llval, four);

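        // Collapses a 3-element dimension array into a single count (x * y * z).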
        fn mul_dim3<'ll, 'tcx>(
            builder: &mut Builder<'_, 'll, 'tcx>,
            arr: &'ll Value,
        ) -> &'ll Value {
            let x = builder.extract_value(arr, 0);
            let y = builder.extract_value(arr, 1);
            let z = builder.extract_value(arr, 2);

            let xy = builder.mul(x, y);
            builder.mul(xy, z)
        }

        let num_workgroups = mul_dim3(builder, workgroup_val);
        let threads_per_block = mul_dim3(builder, thread_val);

        OffloadKernelDims {
            workgroup_dims: workgroup_val,
            thread_dims: thread_val,
            num_workgroups,
            threads_per_block,
        }
    }
}

// ; Function Attrs: nounwind
// declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) #2
fn generate_launcher<'ll>(cx: &CodegenCx<'ll, '_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
    let tptr = cx.type_ptr();
    let ti64 = cx.type_i64();
    let ti32 = cx.type_i32();
    let args = vec![tptr, ti64, ti32, ti32, tptr, tptr];
    let tgt_fn_ty = cx.type_func(&args, ti32);
    let name = "__tgt_target_kernel";
    let tgt_decl = declare_offload_fn(&cx, name, tgt_fn_ty);
    let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
    attributes::apply_to_llfn(tgt_decl, Function, &[nounwind]);
    (tgt_decl, tgt_fn_ty)
}

// What is our @1 here? A magic global, used in our data_{begin/update/end}_mapper:
// @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
// @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
// FIXME(offload): @0 should include the file name (e.g. lib.rs) in which the function to be
// offloaded was defined.
pub(crate) fn generate_at_one<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll llvm::Value {
    let unknown_txt = ";unknown;unknown;0;0;;";
    let c_entry_name = CString::new(unknown_txt).unwrap();
    let c_val = c_entry_name.as_bytes_with_nul();
    let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
    let at_zero = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
    llvm::set_alignment(at_zero, Align::ONE);

    // @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
    let struct_ident_ty = cx.type_named_struct("struct.ident_t");
    let struct_elems = vec![
        cx.get_const_i32(0),
        cx.get_const_i32(2),
        cx.get_const_i32(0),
        cx.get_const_i32(22),
        at_zero,
    ];
    let struct_elems_ty: Vec<_> = struct_elems.iter().map(|&x| cx.val_ty(x)).collect();
    let initializer = crate::common::named_struct(struct_ident_ty, &struct_elems);
    cx.set_struct_body(struct_ident_ty, &struct_elems_ty, false);
    let at_one = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
    llvm::set_alignment(at_one, Align::EIGHT);
    at_one
}

pub(crate) struct TgtOffloadEntry {
    //   uint64_t Reserved;
    //   uint16_t Version;
    //   uint16_t Kind;
    //   uint32_t Flags; Flags associated with the entry (see Target Region Entry Flags)
    //   void *Address; Address of global symbol within device image (function or global)
    //   char *SymbolName;
    //   uint64_t Size; Size of the entry info (0 if it is a function)
    //   uint64_t Data;
    //   void *AuxAddr;
}

impl TgtOffloadEntry {
    pub(crate) fn new_decl<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll llvm::Type {
        let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
        let tptr = cx.type_ptr();
        let ti64 = cx.type_i64();
        let ti32 = cx.type_i32();
        let ti16 = cx.type_i16();
        // For each kernel to run on the gpu, we will later generate one entry of this type.
        // The layout is copied from LLVM.
        let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
        cx.set_struct_body(offload_entry_ty, &entry_elements, false);
        offload_entry_ty
    }

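    // Builds the constant field values for one `__tgt_offload_entry` describing a kernel:
    // everything is zero except version/kind (both 1) and the two pointers (region id and
    // symbol name).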
    fn new<'ll>(
        cx: &CodegenCx<'ll, '_>,
        region_id: &'ll Value,
        llglobal: &'ll Value,
    ) -> [&'ll Value; 9] {
        let reserved = cx.get_const_i64(0);
        let version = cx.get_const_i16(1);
        let kind = cx.get_const_i16(1);
        let flags = cx.get_const_i32(0);
        let size = cx.get_const_i64(0);
        let data = cx.get_const_i64(0);
        let aux_addr = cx.const_null(cx.type_ptr());
        [reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr]
    }
}

// Taken from the LLVM APITypes.h declaration:
struct KernelArgsTy {
    //  uint32_t Version = 0; // Version of this struct for ABI compatibility.
    //  uint32_t NumArgs = 0; // Number of arguments in each input pointer.
    //  void **ArgBasePtrs = nullptr;    // Base pointer of each argument (e.g. a struct).
    //  void **ArgPtrs = nullptr;    // Pointer to the argument data.
    //  int64_t *ArgSizes = nullptr; // Size of the argument data in bytes.
    //  int64_t *ArgTypes = nullptr; // Type of the data (e.g. to / from).
    //  void **ArgNames = nullptr;   // Name of the data for debugging, possibly null.
    //  void **ArgMappers = nullptr; // User-defined mappers, possibly null.
    //  uint64_t Tripcount = 0;      // Tripcount for the teams / distribute loop, 0 otherwise.
    //  struct {
    //    uint64_t NoWait : 1; // Was this kernel spawned with a `nowait` clause.
    //    uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA.
    //    uint64_t Unused : 62;
    //  } Flags = {0, 0, 0}; // totals to 64 Bit, 8 Byte
    //  // The number of teams (for x,y,z dimension).
    //  uint32_t NumTeams[3] = {0, 0, 0};
    //  // The number of threads (for x,y,z dimension).
    //  uint32_t ThreadLimit[3] = {0, 0, 0};
    //  uint32_t DynCGroupMem = 0; // Amount of dynamic cgroup memory requested.
}

impl KernelArgsTy {
    const OFFLOAD_VERSION: u64 = 3;
    const FLAGS: u64 = 0;
    const TRIPCOUNT: u64 = 0;
    fn new_decl<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll Type {
        let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
        let tptr = cx.type_ptr();
        let ti64 = cx.type_i64();
        let ti32 = cx.type_i32();
        let tarr = cx.type_array(ti32, 3);

        let kernel_elements =
            vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];

        cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
        kernel_arguments_ty
    }

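    // Returns the 13 (alignment, value) pairs in the field order of
    // `struct.__tgt_kernel_arguments`; they are stored into the kernel_args alloca right
    // before the launch.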
    fn new<'ll, 'tcx>(
        cx: &CodegenCx<'ll, 'tcx>,
        num_args: u64,
        memtransfer_types: &'ll Value,
        geps: [&'ll Value; 3],
        workgroup_dims: &'ll Value,
        thread_dims: &'ll Value,
    ) -> [(Align, &'ll Value); 13] {
        let four = Align::from_bytes(4).expect("4 Byte alignment should work");
        let eight = Align::EIGHT;

        [
            (four, cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION)),
            (four, cx.get_const_i32(num_args)),
            (eight, geps[0]),
            (eight, geps[1]),
            (eight, geps[2]),
            (eight, memtransfer_types),
            // The next two are debug info pointers. FIXME(offload): set them.
            (eight, cx.const_null(cx.type_ptr())), // dbg
            (eight, cx.const_null(cx.type_ptr())), // dbg
            (eight, cx.get_const_i64(KernelArgsTy::TRIPCOUNT)),
            (eight, cx.get_const_i64(KernelArgsTy::FLAGS)),
            (four, workgroup_dims),
            (four, thread_dims),
            (four, cx.get_const_i32(0)),
        ]
    }
}

// Contains LLVM values needed to manage offloading for a single kernel.
#[derive(Copy, Clone)]
pub(crate) struct OffloadKernelGlobals<'ll> {
    pub offload_sizes: &'ll llvm::Value,
    pub memtransfer_types: &'ll llvm::Value,
    pub region_id: &'ll llvm::Value,
    pub offload_entry: &'ll llvm::Value,
}

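// Declares the three `__tgt_target_data_{begin,update,end}_mapper` runtime functions, which all
// share one signature, and returns them together with that function type.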
fn gen_tgt_data_mappers<'ll>(
    cx: &CodegenCx<'ll, '_>,
) -> (&'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Type) {
    let tptr = cx.type_ptr();
    let ti64 = cx.type_i64();
    let ti32 = cx.type_i32();

    let args = vec![tptr, ti64, ti32, tptr, tptr, tptr, tptr, tptr, tptr];
    let mapper_fn_ty = cx.type_func(&args, cx.type_void());
    let mapper_begin = "__tgt_target_data_begin_mapper";
    let mapper_update = "__tgt_target_data_update_mapper";
    let mapper_end = "__tgt_target_data_end_mapper";
    let begin_mapper_decl = declare_offload_fn(&cx, mapper_begin, mapper_fn_ty);
    let update_mapper_decl = declare_offload_fn(&cx, mapper_update, mapper_fn_ty);
    let end_mapper_decl = declare_offload_fn(&cx, mapper_end, mapper_fn_ty);

    let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
    attributes::apply_to_llfn(begin_mapper_decl, Function, &[nounwind]);
    attributes::apply_to_llfn(update_mapper_decl, Function, &[nounwind]);
    attributes::apply_to_llfn(end_mapper_decl, Function, &[nounwind]);

    (begin_mapper_decl, update_mapper_decl, end_mapper_decl, mapper_fn_ty)
}

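// Emits a private, unnamed_addr constant i64 array (used for the per-kernel
// `.offload_sizes.*` and `.offload_maptypes.*` globals).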
fn add_priv_unnamed_arr<'ll>(cx: &SimpleCx<'ll>, name: &str, vals: &[u64]) -> &'ll llvm::Value {
    let ti64 = cx.type_i64();
    let mut size_val = Vec::with_capacity(vals.len());
    for &val in vals {
        size_val.push(cx.get_const_i64(val));
    }
    let initializer = cx.const_array(ti64, &size_val);
    add_unnamed_global(cx, name, initializer, PrivateLinkage)
}

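// Like `add_global` below, but additionally marks the global as `unnamed_addr`.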
pub(crate) fn add_unnamed_global<'ll>(
    cx: &SimpleCx<'ll>,
    name: &str,
    initializer: &'ll llvm::Value,
    l: Linkage,
) -> &'ll llvm::Value {
    let llglobal = add_global(cx, name, initializer, l);
    llvm::LLVMSetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global);
    llglobal
}

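// Adds a constant global with the given name, linkage, and initializer to the module.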
pub(crate) fn add_global<'ll>(
    cx: &SimpleCx<'ll>,
    name: &str,
    initializer: &'ll llvm::Value,
    l: Linkage,
) -> &'ll llvm::Value {
    let c_name = CString::new(name).unwrap();
    let llglobal: &'ll llvm::Value = llvm::add_global(cx.llmod, cx.val_ty(initializer), &c_name);
    llvm::set_global_constant(llglobal, true);
    llvm::set_linkage(llglobal, l);
    llvm::set_initializer(llglobal, initializer);
    llglobal
}

// Generates (and caches, per symbol) the per-kernel offload globals: a memtransfer (maptypes)
// array which encodes how arguments to this kernel shall be mapped to/from the gpu, a region_id
// global named after this kernel, the payload sizes array, and the `__tgt_offload_entry` that
// describes the kernel to the offload runtime.
pub(crate) fn gen_define_handling<'ll>(
    cx: &CodegenCx<'ll, '_>,
    metadata: &[OffloadMetadata],
    types: &[&'ll Type],
    symbol: String,
    offload_globals: &OffloadGlobals<'ll>,
) -> OffloadKernelGlobals<'ll> {
    if let Some(entry) = cx.offload_kernel_cache.borrow().get(&symbol) {
        return *entry;
    }

    let offload_entry_ty = offload_globals.offload_entry_ty;

    // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
    // reference) types.
    let ptr_meta = types.iter().zip(metadata).filter_map(|(&x, meta)| match cx.type_kind(x) {
        rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta),
        _ => None,
    });

    // FIXME(Sa4dUs): add `OMP_MAP_TARGET_PARAM = 0x20` only if necessary
    let (ptr_sizes, ptr_transfer): (Vec<_>, Vec<_>) =
        ptr_meta.map(|m| (m.payload_size, m.mode.bits() | 0x20)).unzip();

    let offload_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &ptr_sizes);
    // Here we figure out whether something needs to be copied to the gpu (=1), from the gpu (=2),
    // or both to and from the gpu (=3). Other values shouldn't affect us for now.
    // A non-mutable reference or pointer will be 1; an array that is never read but fully
    // overwritten will be 2. For now, everything is 3, until we have our frontend set up.
    // 1+2+32: 1 (MapTo), 2 (MapFrom), 32 (add one extra input ptr per function, to be used later).
    let memtransfer_types =
        add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}"), &ptr_transfer);

    // Next: for each kernel, generate these three entries: a weak region_id constant,
    // the llvm.rodata entry name, and the llvm_offload_entries value.

    let name = format!(".{symbol}.region_id");
    let initializer = cx.get_const_i8(0);
    let region_id = add_global(&cx, &name, initializer, WeakAnyLinkage);

    let c_entry_name = CString::new(symbol.clone()).unwrap();
    let c_val = c_entry_name.as_bytes_with_nul();
    let offload_entry_name = format!(".offloading.entry_name.{symbol}");

    let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
    let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
    llvm::set_alignment(llglobal, Align::ONE);
    llvm::set_section(llglobal, c".llvm.rodata.offloading");

    let name = format!(".offloading.entry.{symbol}");

    // See the __tgt_offload_entry documentation above.
    let elems = TgtOffloadEntry::new(&cx, region_id, llglobal);

    let initializer = crate::common::named_struct(offload_entry_ty, &elems);
    let c_name = CString::new(name).unwrap();
    let offload_entry = llvm::add_global(cx.llmod, offload_entry_ty, &c_name);
    llvm::set_global_constant(offload_entry, true);
    llvm::set_linkage(offload_entry, WeakAnyLinkage);
    llvm::set_initializer(offload_entry, initializer);
    llvm::set_alignment(offload_entry, Align::EIGHT);
    let c_section_name = CString::new("llvm_offload_entries").unwrap();
    llvm::set_section(offload_entry, &c_section_name);

    let result =
        OffloadKernelGlobals { offload_sizes, memtransfer_types, region_id, offload_entry };

    cx.offload_kernel_cache.borrow_mut().insert(symbol, result);

    result
}

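// Declares an external offload runtime function with the C calling convention and default
// visibility.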
fn declare_offload_fn<'ll>(
    cx: &CodegenCx<'ll, '_>,
    name: &str,
    ty: &'ll llvm::Type,
) -> &'ll llvm::Value {
    crate::declare::declare_simple_fn(
        cx,
        name,
        llvm::CallConv::CCallConv,
        llvm::UnnamedAddr::No,
        llvm::Visibility::Default,
        ty,
    )
}

// For each kernel *call*, we now use some of the previously declared globals to move data to and
// from the gpu. For now, we only handle the data-transfer part of it.
// If two consecutive kernels use the same memory, we still move it to the host and back to the gpu.
// Since in our frontend users (by default) don't have to specify data transfers, this is something
// we should optimize in the future! We also assume that everything should be copied back and
// forth, but sometimes we could directly zero-allocate on the device and only move data back, or,
// if something is immutable, only copy it to the device but not back.
//
// Current steps:
// 0. Alloca some variables for the following steps.
// 1. Set the insert point before the kernel call.
// 2. Generate all the GEPs and stores, to be used in 3).
// 3. Generate __tgt_target_data_begin calls to move data to the GPU.
//
// Unchanged: keep the kernel call. Later we will move the kernel itself to the GPU.
//
// 4. Set the insert point after the kernel call.
// 5. Generate all the GEPs and stores, to be used in 6).
// 6. Generate __tgt_target_data_end calls to move data from the GPU.
pub(crate) fn gen_call_handling<'ll, 'tcx>(
    builder: &mut Builder<'_, 'll, 'tcx>,
    offload_data: &OffloadKernelGlobals<'ll>,
    args: &[&'ll Value],
    types: &[&Type],
    metadata: &[OffloadMetadata],
    offload_globals: &OffloadGlobals<'ll>,
    offload_dims: &OffloadKernelDims<'ll>,
) {
    let cx = builder.cx;
    let OffloadKernelGlobals { offload_sizes, offload_entry, memtransfer_types, region_id } =
        offload_data;
    let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } =
        offload_dims;

    let tgt_decl = offload_globals.launcher_fn;
    let tgt_target_kernel_ty = offload_globals.launcher_ty;

    // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
    let tgt_bin_desc = offload_globals.bin_desc;

    let tgt_kernel_decl = offload_globals.kernel_args_ty;
    let begin_mapper_decl = offload_globals.begin_mapper;
    let end_mapper_decl = offload_globals.end_mapper;
    let fn_ty = offload_globals.mapper_fn_ty;

    let num_args = types.len() as u64;
    let bb = builder.llbb();

    // FIXME(Sa4dUs): the dummy loads are a temporary workaround; we should find a proper way to
    // prevent these variables from being optimized away.
    for val in [offload_sizes, offload_entry] {
        unsafe {
            let dummy = llvm::LLVMBuildLoad2(
                &builder.llbuilder,
                llvm::LLVMTypeOf(val),
                val,
                b"dummy\0".as_ptr() as *const _,
            );
            llvm::LLVMSetVolatile(dummy, llvm::TRUE);
        }
    }

    // Step 0)
    // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
    // %6 = alloca %struct.__tgt_bin_desc, align 8
    unsafe {
        llvm::LLVMRustPositionBuilderPastAllocas(&builder.llbuilder, builder.llfn());
    }
    let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc");

    let ty = cx.type_array(cx.type_ptr(), num_args);
    // Baseptrs are just the input pointers to the kernel, stored in a local alloca.
    let a1 = builder.direct_alloca(ty, Align::EIGHT, ".offload_baseptrs");
    // Ptrs are the result of a gep into the baseptr, at least for our trivial types.
    let a2 = builder.direct_alloca(ty, Align::EIGHT, ".offload_ptrs");
    // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
    let ty2 = cx.type_array(cx.type_i64(), num_args);
    let a4 = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes");

    // %kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
    let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args");

    // Step 1)
    unsafe {
        llvm::LLVMPositionBuilderAtEnd(&builder.llbuilder, bb);
    }
    builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);

    // Now, for each function param, record its base pointer and a gep into it; these get passed
    // to our mapper calls below.
    let mut vals = vec![];
    let mut geps = vec![];
    let i32_0 = cx.get_const_i32(0);
    for &v in args {
        let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]);
        vals.push(v);
        geps.push(gep);
    }

    let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void());
    let register_lib_decl = offload_globals.register_lib;
    let unregister_lib_decl = offload_globals.unregister_lib;
    let init_ty = cx.type_func(&[], cx.type_void());
    let init_rtls_decl = offload_globals.init_rtls;

    // FIXME(offload): Later we want to add them to the wrapper code, rather than our main function.
    // call void @__tgt_register_lib(ptr noundef %6)
    builder.call(mapper_fn_ty, None, None, register_lib_decl, &[tgt_bin_desc_alloca], None, None);
    // call void @__tgt_init_all_rtls()
    builder.call(init_ty, None, None, init_rtls_decl, &[], None, None);

    for i in 0..num_args {
        let idx = cx.get_const_i32(i);
        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]);
        builder.store(vals[i as usize], gep1, Align::EIGHT);
        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, idx]);
        builder.store(geps[i as usize], gep2, Align::EIGHT);
        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, idx]);
        // FIXME(offload): write an offload frontend and handle arbitrary types.
        builder.store(cx.get_const_i64(metadata[i as usize].payload_size), gep3, Align::EIGHT);
    }

    // For now we have a very simplistic indexing scheme into our
    // offload_{baseptrs,ptrs,sizes}. We will probably improve this along with our GPU frontend PR.
    fn get_geps<'ll, 'tcx>(
        builder: &mut Builder<'_, 'll, 'tcx>,
        ty: &'ll Type,
        ty2: &'ll Type,
        a1: &'ll Value,
        a2: &'ll Value,
        a4: &'ll Value,
    ) -> [&'ll Value; 3] {
        let cx = builder.cx;
        let i32_0 = cx.get_const_i32(0);

        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
        [gep1, gep2, gep3]
    }

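    // Emits one `__tgt_target_data_{begin,end}_mapper` call with the standard argument layout:
    // ident, device id (i64 -1 = default device), number of args, base pointers, pointers, sizes,
    // map types, and two unused null pointers (arg names and arg mappers).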
    fn generate_mapper_call<'ll, 'tcx>(
        builder: &mut Builder<'_, 'll, 'tcx>,
        geps: [&'ll Value; 3],
        o_type: &'ll Value,
        fn_to_call: &'ll Value,
        fn_ty: &'ll Type,
        num_args: u64,
        s_ident_t: &'ll Value,
    ) {
        let cx = builder.cx;
        let nullptr = cx.const_null(cx.type_ptr());
        let i64_max = cx.get_const_i64(u64::MAX);
        let num_args = cx.get_const_i32(num_args);
        let args =
            vec![s_ident_t, i64_max, num_args, geps[0], geps[1], geps[2], o_type, nullptr, nullptr];
        builder.call(fn_ty, None, None, fn_to_call, &args, None, None);
    }

    // Step 2)
    let s_ident_t = offload_globals.ident_t_global;
    let geps = get_geps(builder, ty, ty2, a1, a2, a4);
    generate_mapper_call(
        builder,
        geps,
        memtransfer_types,
        begin_mapper_decl,
        fn_ty,
        num_args,
        s_ident_t,
    );
    let values =
        KernelArgsTy::new(&cx, num_args, memtransfer_types, geps, workgroup_dims, thread_dims);

    // Step 3)
    // Here we fill the KernelArgsTy; see the documentation above.
    for (i, value) in values.iter().enumerate() {
        let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]);
        builder.store(value.1, ptr, value.0);
    }

    let args = vec![
        s_ident_t,
        // FIXME(offload): give users a way to select which GPU to use.
        cx.get_const_i64(u64::MAX), // MAX == -1.
        num_workgroups,
        threads_per_block,
        region_id,
        a5,
    ];
    builder.call(tgt_target_kernel_ty, None, None, tgt_decl, &args, None, None);
    // %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)

    // Step 4)
    let geps = get_geps(builder, ty, ty2, a1, a2, a4);
    generate_mapper_call(
        builder,
        geps,
        memtransfer_types,
        end_mapper_decl,
        fn_ty,
        num_args,
        s_ident_t,
    );

    builder.call(mapper_fn_ty, None, None, unregister_lib_decl, &[tgt_bin_desc_alloca], None, None);
}