// rustc_codegen_llvm/builder/gpu_offload.rs

use std::ffi::CString;

use bitflags::Flags;
use llvm::Linkage::*;
use rustc_abi::Align;
use rustc_codegen_ssa::MemFlags;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_middle::bug;
use rustc_middle::ty::offload_meta::{MappingFlags, OffloadMetadata, OffloadSize};

use crate::builder::Builder;
use crate::common::CodegenCx;
use crate::llvm::AttributePlace::Function;
use crate::llvm::{self, Linkage, Type, Value};
use crate::{SimpleCx, attributes};
18
19// LLVM kernel-independent globals required for offloading
20pub(crate) struct OffloadGlobals<'ll> {
21    pub launcher_fn: &'ll llvm::Value,
22    pub launcher_ty: &'ll llvm::Type,
23
24    pub kernel_args_ty: &'ll llvm::Type,
25
26    pub offload_entry_ty: &'ll llvm::Type,
27
28    pub begin_mapper: &'ll llvm::Value,
29    pub end_mapper: &'ll llvm::Value,
30    pub mapper_fn_ty: &'ll llvm::Type,
31
32    pub ident_t_global: &'ll llvm::Value,
33}
34
35impl<'ll> OffloadGlobals<'ll> {
36    pub(crate) fn declare(cx: &CodegenCx<'ll, '_>) -> Self {
37        let (launcher_fn, launcher_ty) = generate_launcher(cx);
38        let kernel_args_ty = KernelArgsTy::new_decl(cx);
39        let offload_entry_ty = TgtOffloadEntry::new_decl(cx);
40        let (begin_mapper, _, end_mapper, mapper_fn_ty) = gen_tgt_data_mappers(cx);
41        let ident_t_global = generate_at_one(cx);
42
43        // We want LLVM's openmp-opt pass to pick up and optimize this module, since it covers both
44        // openmp and offload optimizations.
45        llvm::add_module_flag_u32(cx.llmod(), llvm::ModuleFlagMergeBehavior::Max, "openmp", 51);
46
47        OffloadGlobals {
48            launcher_fn,
49            launcher_ty,
50            kernel_args_ty,
51            offload_entry_ty,
52            begin_mapper,
53            end_mapper,
54            mapper_fn_ty,
55            ident_t_global,
56        }
57    }
58}
59
60// We need to register offload before using it. We also should unregister it once we are done, for
61// good measures. Previously we have done so before and after each individual offload intrinsic
62// call, but that comes at a performance cost. The repeated (un)register calls might also confuse
63// the LLVM ompOpt pass, which tries to move operations to a better location. The easiest solution,
64// which we copy from clang, is to just have those two calls once, in the global ctor/dtor section
65// of the final binary.
66pub(crate) fn register_offload<'ll>(cx: &CodegenCx<'ll, '_>) {
67    // First we check quickly whether we already have done our setup, in which case we return early.
68    // Shouldn't be needed for correctness.
69    let register_lib_name = "__tgt_register_lib";
70    if cx.get_function(register_lib_name).is_some() {
71        return;
72    }
73
74    let reg_lib_decl = cx.type_func(&[cx.type_ptr()], cx.type_void());
75    let register_lib = declare_offload_fn(&cx, register_lib_name, reg_lib_decl);
76    let unregister_lib = declare_offload_fn(&cx, "__tgt_unregister_lib", reg_lib_decl);
77
78    let ptr_null = cx.const_null(cx.type_ptr());
79    let const_struct = cx.const_struct(&[cx.get_const_i32(0), ptr_null, ptr_null, ptr_null], false);
80    let omp_descriptor =
81        add_global(cx, ".omp_offloading.descriptor", const_struct, InternalLinkage);
82    // @.omp_offloading.descriptor = internal constant %__tgt_bin_desc { i32 1, ptr @.omp_offloading.device_images, ptr @__start_llvm_offload_entries, ptr @__stop_llvm_offload_entries }
83    // @.omp_offloading.descriptor = internal constant %__tgt_bin_desc { i32 0, ptr null, ptr null, ptr null }
84
85    let atexit = cx.type_func(&[cx.type_ptr()], cx.type_i32());
86    let atexit_fn = declare_offload_fn(cx, "atexit", atexit);
87
88    // FIXME(offload): Drop this, once we fully automated our offload compilation pipeline, since
89    // LLVM will initialize them for us if it sees gpu kernels being registered.
90    let init_ty = cx.type_func(&[], cx.type_void());
91    let init_rtls = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);
92
93    let desc_ty = cx.type_func(&[], cx.type_void());
94    let reg_name = ".omp_offloading.descriptor_reg";
95    let unreg_name = ".omp_offloading.descriptor_unreg";
96    let desc_reg_fn = declare_offload_fn(cx, reg_name, desc_ty);
97    let desc_unreg_fn = declare_offload_fn(cx, unreg_name, desc_ty);
98    llvm::set_linkage(desc_reg_fn, InternalLinkage);
99    llvm::set_linkage(desc_unreg_fn, InternalLinkage);
100    llvm::set_section(desc_reg_fn, c".text.startup");
101    llvm::set_section(desc_unreg_fn, c".text.startup");
102
103    // define internal void @.omp_offloading.descriptor_reg() section ".text.startup" {
104    // entry:
105    //   call void @__tgt_register_lib(ptr @.omp_offloading.descriptor)
106    //   call void @__tgt_init_all_rtls()
107    //   %0 = call i32 @atexit(ptr @.omp_offloading.descriptor_unreg)
108    //   ret void
109    // }
110    let bb = Builder::append_block(cx, desc_reg_fn, "entry");
111    let mut a = Builder::build(cx, bb);
112    a.call(reg_lib_decl, None, None, register_lib, &[omp_descriptor], None, None);
113    a.call(init_ty, None, None, init_rtls, &[], None, None);
114    a.call(atexit, None, None, atexit_fn, &[desc_unreg_fn], None, None);
115    a.ret_void();
116
117    // define internal void @.omp_offloading.descriptor_unreg() section ".text.startup" {
118    // entry:
119    //   call void @__tgt_unregister_lib(ptr @.omp_offloading.descriptor)
120    //   ret void
121    // }
122    let bb = Builder::append_block(cx, desc_unreg_fn, "entry");
123    let mut a = Builder::build(cx, bb);
124    a.call(reg_lib_decl, None, None, unregister_lib, &[omp_descriptor], None, None);
125    a.ret_void();
126
127    // @llvm.global_ctors = appending global [1 x { i32, ptr, ptr }] [{ i32, ptr, ptr } { i32 101, ptr @.omp_offloading.descriptor_reg, ptr null }]
128    let args = ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
        [cx.get_const_i32(101), desc_reg_fn, ptr_null]))vec![cx.get_const_i32(101), desc_reg_fn, ptr_null];
129    let const_struct = cx.const_struct(&args, false);
130    let arr = cx.const_array(cx.val_ty(const_struct), &[const_struct]);
131    add_global(cx, "llvm.global_ctors", arr, AppendingLinkage);
132}
133
134pub(crate) struct OffloadKernelDims<'ll> {
135    num_workgroups: &'ll Value,
136    threads_per_block: &'ll Value,
137    workgroup_dims: &'ll Value,
138    thread_dims: &'ll Value,
139}
140
141impl<'ll> OffloadKernelDims<'ll> {
142    pub(crate) fn from_operands<'tcx>(
143        builder: &mut Builder<'_, 'll, 'tcx>,
144        workgroup_op: &OperandRef<'tcx, &'ll llvm::Value>,
145        thread_op: &OperandRef<'tcx, &'ll llvm::Value>,
146    ) -> Self {
147        let cx = builder.cx;
148        let arr_ty = cx.type_array(cx.type_i32(), 3);
149        let four = Align::from_bytes(4).unwrap();
150
151        let OperandValue::Ref(place) = workgroup_op.val else {
152            ::rustc_middle::util::bug::bug_fmt(format_args!("expected array operand by reference"));bug!("expected array operand by reference");
153        };
154        let workgroup_val = builder.load(arr_ty, place.llval, four);
155
156        let OperandValue::Ref(place) = thread_op.val else {
157            ::rustc_middle::util::bug::bug_fmt(format_args!("expected array operand by reference"));bug!("expected array operand by reference");
158        };
159        let thread_val = builder.load(arr_ty, place.llval, four);
160
161        fn mul_dim3<'ll, 'tcx>(
162            builder: &mut Builder<'_, 'll, 'tcx>,
163            arr: &'ll Value,
164        ) -> &'ll Value {
165            let x = builder.extract_value(arr, 0);
166            let y = builder.extract_value(arr, 1);
167            let z = builder.extract_value(arr, 2);
168
169            let xy = builder.mul(x, y);
170            builder.mul(xy, z)
171        }
172
173        let num_workgroups = mul_dim3(builder, workgroup_val);
174        let threads_per_block = mul_dim3(builder, thread_val);
175
176        OffloadKernelDims {
177            workgroup_dims: workgroup_val,
178            thread_dims: thread_val,
179            num_workgroups,
180            threads_per_block,
181        }
182    }
183}
184
185// ; Function Attrs: nounwind
186// declare i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr) #2
187fn generate_launcher<'ll>(cx: &CodegenCx<'ll, '_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
188    let tptr = cx.type_ptr();
189    let ti64 = cx.type_i64();
190    let ti32 = cx.type_i32();
191    let args = ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
        [tptr, ti64, ti32, ti32, tptr, tptr]))vec![tptr, ti64, ti32, ti32, tptr, tptr];
192    let tgt_fn_ty = cx.type_func(&args, ti32);
193    let name = "__tgt_target_kernel";
194    let tgt_decl = declare_offload_fn(&cx, name, tgt_fn_ty);
195    let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
196    attributes::apply_to_llfn(tgt_decl, Function, &[nounwind]);
197    (tgt_decl, tgt_fn_ty)
198}
199
200// What is our @1 here? A magic global, used in our data_{begin/update/end}_mapper:
201// @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
202// @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
203// FIXME(offload): @0 should include the file name (e.g. lib.rs) in which the function to be
204// offloaded was defined.
205pub(crate) fn generate_at_one<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll llvm::Value {
206    let unknown_txt = ";unknown;unknown;0;0;;";
207    let c_entry_name = CString::new(unknown_txt).unwrap();
208    let c_val = c_entry_name.as_bytes_with_nul();
209    let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
210    let at_zero = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
211    llvm::set_alignment(at_zero, Align::ONE);
212
213    // @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
214    let struct_ident_ty = cx.type_named_struct("struct.ident_t");
215    let struct_elems = ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
        [cx.get_const_i32(0), cx.get_const_i32(2), cx.get_const_i32(0),
                cx.get_const_i32(22), at_zero]))vec![
216        cx.get_const_i32(0),
217        cx.get_const_i32(2),
218        cx.get_const_i32(0),
219        cx.get_const_i32(22),
220        at_zero,
221    ];
222    let struct_elems_ty: Vec<_> = struct_elems.iter().map(|&x| cx.val_ty(x)).collect();
223    let initializer = crate::common::named_struct(struct_ident_ty, &struct_elems);
224    cx.set_struct_body(struct_ident_ty, &struct_elems_ty, false);
225    let at_one = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
226    llvm::set_alignment(at_one, Align::EIGHT);
227    at_one
228}
229
/// Field-less marker type; its associated functions build the LLVM-side
/// `struct.__tgt_offload_entry` type and values. The C layout it mirrors is listed below.
pub(crate) struct TgtOffloadEntry {
    //   uint64_t Reserved;
    //   uint16_t Version;
    //   uint16_t Kind;
    //   uint32_t Flags; Flags associated with the entry (see Target Region Entry Flags)
    //   void *Address; Address of global symbol within device image (function or global)
    //   char *SymbolName;
    //   uint64_t Size; Size of the entry info (0 if it is a function)
    //   uint64_t Data;
    //   void *AuxAddr;
}
241
242impl TgtOffloadEntry {
243    pub(crate) fn new_decl<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll llvm::Type {
244        let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
245        let tptr = cx.type_ptr();
246        let ti64 = cx.type_i64();
247        let ti32 = cx.type_i32();
248        let ti16 = cx.type_i16();
249        // For each kernel to run on the gpu, we will later generate one entry of this type.
250        // copied from LLVM
251        let entry_elements = ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
        [ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr]))vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
252        cx.set_struct_body(offload_entry_ty, &entry_elements, false);
253        offload_entry_ty
254    }
255
256    fn new<'ll>(
257        cx: &CodegenCx<'ll, '_>,
258        region_id: &'ll Value,
259        llglobal: &'ll Value,
260    ) -> [&'ll Value; 9] {
261        let reserved = cx.get_const_i64(0);
262        let version = cx.get_const_i16(1);
263        let kind = cx.get_const_i16(1);
264        let flags = cx.get_const_i32(0);
265        let size = cx.get_const_i64(0);
266        let data = cx.get_const_i64(0);
267        let aux_addr = cx.const_null(cx.type_ptr());
268        [reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr]
269    }
270}
271
// Taken from the LLVM APITypes.h declaration:
/// Field-less marker type; its associated functions build the LLVM-side
/// `struct.__tgt_kernel_arguments` type and values. The C++ layout it mirrors is listed below.
struct KernelArgsTy {
    //  uint32_t Version = 0; // Version of this struct for ABI compatibility.
    //  uint32_t NumArgs = 0; // Number of arguments in each input pointer.
    //  void **ArgBasePtrs =
    //      nullptr;                 // Base pointer of each argument (e.g. a struct).
    //  void **ArgPtrs = nullptr;    // Pointer to the argument data.
    //  int64_t *ArgSizes = nullptr; // Size of the argument data in bytes.
    //  int64_t *ArgTypes = nullptr; // Type of the data (e.g. to / from).
    //  void **ArgNames = nullptr;   // Name of the data for debugging, possibly null.
    //  void **ArgMappers = nullptr; // User-defined mappers, possibly null.
    //  uint64_t Tripcount =
    // 0; // Tripcount for the teams / distribute loop, 0 otherwise.
    // struct {
    //    uint64_t NoWait : 1; // Was this kernel spawned with a `nowait` clause.
    //    uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA.
    //    uint64_t Unused : 62;
    //  } Flags = {0, 0, 0}; // totals to 64 Bit, 8 Byte
    //  // The number of teams (for x,y,z dimension).
    //  uint32_t NumTeams[3] = {0, 0, 0};
    //  // The number of threads (for x,y,z dimension).
    //  uint32_t ThreadLimit[3] = {0, 0, 0};
    //  uint32_t DynCGroupMem = 0; // Amount of dynamic cgroup memory requested.
}
296
297impl KernelArgsTy {
298    const OFFLOAD_VERSION: u64 = 3;
299    const FLAGS: u64 = 0;
300    const TRIPCOUNT: u64 = 0;
301    fn new_decl<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll Type {
302        let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
303        let tptr = cx.type_ptr();
304        let ti64 = cx.type_i64();
305        let ti32 = cx.type_i32();
306        let tarr = cx.type_array(ti32, 3);
307
308        let kernel_elements =
309            ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
        [ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr,
                tarr, ti32]))vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
310
311        cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
312        kernel_arguments_ty
313    }
314
315    fn new<'ll, 'tcx>(
316        cx: &CodegenCx<'ll, 'tcx>,
317        num_args: u64,
318        memtransfer_types: &'ll Value,
319        geps: [&'ll Value; 3],
320        workgroup_dims: &'ll Value,
321        thread_dims: &'ll Value,
322        dyn_cache: &'ll Value,
323    ) -> [(Align, &'ll str, &'ll Value); 13] {
324        let four = Align::from_bytes(4).expect("4 Byte alignment should work");
325        let eight = Align::EIGHT;
326
327        [
328            (four, "Version", cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION)),
329            (four, "NumArgs", cx.get_const_i32(num_args)),
330            (eight, "ArgBasePtrs", geps[0]),
331            (eight, "ArgPtrs", geps[1]),
332            (eight, "ArgSizes", geps[2]),
333            (eight, "ArgTypes", memtransfer_types),
334            // The next two are debug infos. FIXME(offload): set them
335            (eight, "ArgNames", cx.const_null(cx.type_ptr())), // dbg
336            (eight, "ArgMappers", cx.const_null(cx.type_ptr())), // dbg
337            (eight, "Tripcount", cx.get_const_i64(KernelArgsTy::TRIPCOUNT)),
338            (eight, "Flags", cx.get_const_i64(KernelArgsTy::FLAGS)),
339            (four, "NumTeams", workgroup_dims),
340            (four, "ThreadLimit", thread_dims),
341            (four, "DynCGroupMem", dyn_cache),
342        ]
343    }
344}
345
346// Contains LLVM values needed to manage offloading for a single kernel.
347#[derive(#[automatically_derived]
impl<'ll> ::core::marker::Copy for OffloadKernelGlobals<'ll> { }Copy, #[automatically_derived]
impl<'ll> ::core::clone::Clone for OffloadKernelGlobals<'ll> {
    #[inline]
    fn clone(&self) -> OffloadKernelGlobals<'ll> {
        let _: ::core::clone::AssertParamIsClone<&'ll llvm::Value>;
        let _: ::core::clone::AssertParamIsClone<&'ll llvm::Value>;
        let _: ::core::clone::AssertParamIsClone<&'ll llvm::Value>;
        let _: ::core::clone::AssertParamIsClone<&'ll llvm::Value>;
        let _: ::core::clone::AssertParamIsClone<&'ll llvm::Value>;
        *self
    }
}Clone)]
348pub(crate) struct OffloadKernelGlobals<'ll> {
349    pub offload_sizes: &'ll llvm::Value,
350    pub memtransfer_begin: &'ll llvm::Value,
351    pub memtransfer_kernel: &'ll llvm::Value,
352    pub memtransfer_end: &'ll llvm::Value,
353    pub region_id: &'ll llvm::Value,
354}
355
356fn gen_tgt_data_mappers<'ll>(
357    cx: &CodegenCx<'ll, '_>,
358) -> (&'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Type) {
359    let tptr = cx.type_ptr();
360    let ti64 = cx.type_i64();
361    let ti32 = cx.type_i32();
362
363    let args = ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
        [tptr, ti64, ti32, tptr, tptr, tptr, tptr, tptr, tptr]))vec![tptr, ti64, ti32, tptr, tptr, tptr, tptr, tptr, tptr];
364    let mapper_fn_ty = cx.type_func(&args, cx.type_void());
365    let mapper_begin = "__tgt_target_data_begin_mapper";
366    let mapper_update = "__tgt_target_data_update_mapper";
367    let mapper_end = "__tgt_target_data_end_mapper";
368    let begin_mapper_decl = declare_offload_fn(&cx, mapper_begin, mapper_fn_ty);
369    let update_mapper_decl = declare_offload_fn(&cx, mapper_update, mapper_fn_ty);
370    let end_mapper_decl = declare_offload_fn(&cx, mapper_end, mapper_fn_ty);
371
372    let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
373    attributes::apply_to_llfn(begin_mapper_decl, Function, &[nounwind]);
374    attributes::apply_to_llfn(update_mapper_decl, Function, &[nounwind]);
375    attributes::apply_to_llfn(end_mapper_decl, Function, &[nounwind]);
376
377    (begin_mapper_decl, update_mapper_decl, end_mapper_decl, mapper_fn_ty)
378}
379
380fn add_priv_unnamed_arr<'ll>(cx: &SimpleCx<'ll>, name: &str, vals: &[u64]) -> &'ll llvm::Value {
381    let ti64 = cx.type_i64();
382    let mut size_val = Vec::with_capacity(vals.len());
383    for &val in vals {
384        size_val.push(cx.get_const_i64(val));
385    }
386    let initializer = cx.const_array(ti64, &size_val);
387    add_unnamed_global(cx, name, initializer, PrivateLinkage)
388}
389
390pub(crate) fn add_unnamed_global<'ll>(
391    cx: &SimpleCx<'ll>,
392    name: &str,
393    initializer: &'ll llvm::Value,
394    l: Linkage,
395) -> &'ll llvm::Value {
396    let llglobal = add_global(cx, name, initializer, l);
397    llvm::LLVMSetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global);
398    llglobal
399}
400
401pub(crate) fn add_global<'ll>(
402    cx: &SimpleCx<'ll>,
403    name: &str,
404    initializer: &'ll llvm::Value,
405    l: Linkage,
406) -> &'ll llvm::Value {
407    let c_name = CString::new(name).unwrap();
408    let llglobal: &'ll llvm::Value = llvm::add_global(cx.llmod, cx.val_ty(initializer), &c_name);
409    llvm::set_global_constant(llglobal, true);
410    llvm::set_linkage(llglobal, l);
411    llvm::set_initializer(llglobal, initializer);
412    llglobal
413}
414
415// This function returns a memtransfer value which encodes how arguments to this kernel shall be
416// mapped to/from the gpu. It also returns a region_id with the name of this kernel, to be
417// concatenated into the list of region_ids.
418pub(crate) fn gen_define_handling<'ll>(
419    cx: &CodegenCx<'ll, '_>,
420    metadata: &[OffloadMetadata],
421    symbol: String,
422    offload_globals: &OffloadGlobals<'ll>,
423) -> OffloadKernelGlobals<'ll> {
424    if let Some(entry) = cx.offload_kernel_cache.borrow().get(&symbol) {
425        return *entry;
426    }
427
428    let offload_entry_ty = offload_globals.offload_entry_ty;
429
430    let (sizes, transfer): (Vec<_>, Vec<_>) =
431        metadata.iter().map(|m| (m.payload_size, m.mode)).unzip();
432    // Our begin mapper should only see simplified information about which args have to be
433    // transferred to the device, the end mapper only about which args should be transferred back.
434    // Any information beyond that makes it harder for LLVM's opt pass to evaluate whether it can
435    // safely move (=optimize) the LLVM-IR location of this data transfer. Only the mapping types
436    // mentioned below are handled, so make sure that we don't generate any other ones.
437    let handled_mappings = MappingFlags::TO
438        | MappingFlags::FROM
439        | MappingFlags::TARGET_PARAM
440        | MappingFlags::LITERAL
441        | MappingFlags::IMPLICIT;
442    for arg in &transfer {
443        if true {
    if !!arg.contains_unknown_bits() {
        ::core::panicking::panic("assertion failed: !arg.contains_unknown_bits()")
    };
};debug_assert!(!arg.contains_unknown_bits());
444        if true {
    if !handled_mappings.contains(*arg) {
        ::core::panicking::panic("assertion failed: handled_mappings.contains(*arg)")
    };
};debug_assert!(handled_mappings.contains(*arg));
445    }
446
447    let valid_begin_mappings = MappingFlags::TO | MappingFlags::LITERAL | MappingFlags::IMPLICIT;
448    let transfer_to: Vec<u64> =
449        transfer.iter().map(|m| m.intersection(valid_begin_mappings).bits()).collect();
450    let transfer_from: Vec<u64> =
451        transfer.iter().map(|m| m.intersection(MappingFlags::FROM).bits()).collect();
452    let valid_kernel_mappings = MappingFlags::LITERAL | MappingFlags::IMPLICIT;
453    // FIXME(offload): add `OMP_MAP_TARGET_PARAM = 0x20` only if necessary
454    let transfer_kernel: Vec<u64> = transfer
455        .iter()
456        .map(|m| (m.intersection(valid_kernel_mappings) | MappingFlags::TARGET_PARAM).bits())
457        .collect();
458
459    let actual_sizes = sizes
460        .iter()
461        .map(|s| match s {
462            OffloadSize::Static(sz) => *sz,
463            // NOTE(Sa4dUs): set `.offload_sizes` entry to 0 for sizes that we determine at runtime, just like clang
464            _ => 0,
465        })
466        .collect::<Vec<_>>();
467    let offload_sizes =
468        add_priv_unnamed_arr(&cx, &::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!(".offload_sizes.{0}", symbol))
    })format!(".offload_sizes.{symbol}"), &actual_sizes);
469    let memtransfer_begin =
470        add_priv_unnamed_arr(&cx, &::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!(".offload_maptypes.{0}.begin",
                symbol))
    })format!(".offload_maptypes.{symbol}.begin"), &transfer_to);
471    let memtransfer_kernel =
472        add_priv_unnamed_arr(&cx, &::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!(".offload_maptypes.{0}.kernel",
                symbol))
    })format!(".offload_maptypes.{symbol}.kernel"), &transfer_kernel);
473    let memtransfer_end =
474        add_priv_unnamed_arr(&cx, &::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!(".offload_maptypes.{0}.end",
                symbol))
    })format!(".offload_maptypes.{symbol}.end"), &transfer_from);
475
476    // Next: For each function, generate these three entries. A weak constant,
477    // the llvm.rodata entry name, and  the llvm_offload_entries value
478
479    let name = ::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!(".{0}.region_id", symbol))
    })format!(".{symbol}.region_id");
480    let initializer = cx.get_const_i8(0);
481    let region_id = add_global(&cx, &name, initializer, WeakAnyLinkage);
482
483    let c_entry_name = CString::new(symbol.clone()).unwrap();
484    let c_val = c_entry_name.as_bytes_with_nul();
485    let offload_entry_name = ::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!(".offloading.entry_name.{0}",
                symbol))
    })format!(".offloading.entry_name.{symbol}");
486
487    let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
488    let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
489    llvm::set_alignment(llglobal, Align::ONE);
490    llvm::set_section(llglobal, c".llvm.rodata.offloading");
491
492    let name = ::alloc::__export::must_use({
        ::alloc::fmt::format(format_args!(".offloading.entry.{0}", symbol))
    })format!(".offloading.entry.{symbol}");
493
494    // See the __tgt_offload_entry documentation above.
495    let elems = TgtOffloadEntry::new(&cx, region_id, llglobal);
496
497    let initializer = crate::common::named_struct(offload_entry_ty, &elems);
498    let c_name = CString::new(name).unwrap();
499    let offload_entry = llvm::add_global(cx.llmod, offload_entry_ty, &c_name);
500    llvm::set_global_constant(offload_entry, true);
501    llvm::set_linkage(offload_entry, WeakAnyLinkage);
502    llvm::set_initializer(offload_entry, initializer);
503    llvm::set_alignment(offload_entry, Align::EIGHT);
504    let c_section_name = CString::new("llvm_offload_entries").unwrap();
505    llvm::set_section(offload_entry, &c_section_name);
506
507    cx.add_compiler_used_global(offload_entry);
508
509    let result = OffloadKernelGlobals {
510        offload_sizes,
511        memtransfer_begin,
512        memtransfer_kernel,
513        memtransfer_end,
514        region_id,
515    };
516
517    cx.offload_kernel_cache.borrow_mut().insert(symbol, result);
518
519    result
520}
521
522fn declare_offload_fn<'ll>(
523    cx: &CodegenCx<'ll, '_>,
524    name: &str,
525    ty: &'ll llvm::Type,
526) -> &'ll llvm::Value {
527    crate::declare::declare_simple_fn(
528        cx,
529        name,
530        llvm::CallConv::CCallConv,
531        llvm::UnnamedAddr::No,
532        llvm::Visibility::Default,
533        ty,
534    )
535}
536
537pub(crate) fn scalar_width<'ll>(cx: &'ll SimpleCx<'_>, ty: &'ll Type) -> u64 {
538    match cx.type_kind(ty) {
539        TypeKind::Half
540        | TypeKind::Float
541        | TypeKind::Double
542        | TypeKind::X86_FP80
543        | TypeKind::FP128
544        | TypeKind::PPC_FP128 => cx.float_width(ty) as u64,
545        TypeKind::Integer => cx.int_width(ty),
546        other => ::rustc_middle::util::bug::bug_fmt(format_args!("scalar_width was called on a non scalar type {0:?}",
        other))bug!("scalar_width was called on a non scalar type {other:?}"),
547    }
548}
549
550fn get_runtime_size<'ll, 'tcx>(
551    builder: &mut Builder<'_, 'll, 'tcx>,
552    args: &[&'ll Value],
553    index: usize,
554    meta: &OffloadMetadata,
555) -> &'ll Value {
556    match meta.payload_size {
557        OffloadSize::Slice { element_size } => {
558            let length_idx = index + 1;
559            let length = args[length_idx];
560            let length_i64 = builder.intcast(length, builder.cx.type_i64(), false);
561            builder.mul(length_i64, builder.cx.get_const_i64(element_size))
562        }
563        _ => ::rustc_middle::util::bug::bug_fmt(format_args!("unexpected offload size {0:?}",
        meta.payload_size))bug!("unexpected offload size {:?}", meta.payload_size),
564    }
565}
566
567// For each kernel *call*, we now use some of our previous declared globals to move data to and from
568// the gpu. For now, we only handle the data transfer part of it.
569// If two consecutive kernels use the same memory, we still move it to the host and back to the gpu.
570// Since in our frontend users (by default) don't have to specify data transfer, this is something
571// we should optimize in the future! In some cases we can directly zero-allocate on the device and
572// only move data back, or if something is immutable, we might only copy it to the device.
573//
574// Current steps:
575// 0. Alloca some variables for the following steps
576// 1. set insert point before kernel call.
577// 2. generate all the GEPS and stores, to be used in 3)
578// 3. generate __tgt_target_data_begin calls to move data to the GPU
579//
580// unchanged: keep kernel call. Later move the kernel to the GPU
581//
582// 4. set insert point after kernel call.
583// 5. generate all the GEPS and stores, to be used in 6)
584// 6. generate __tgt_target_data_end calls to move data from the GPU
585pub(crate) fn gen_call_handling<'ll, 'tcx>(
586    builder: &mut Builder<'_, 'll, 'tcx>,
587    offload_data: &OffloadKernelGlobals<'ll>,
588    args: &[&'ll Value],
589    types: &[&Type],
590    metadata: &[OffloadMetadata],
591    offload_globals: &OffloadGlobals<'ll>,
592    offload_dims: &OffloadKernelDims<'ll>,
593    dyn_cache: &'ll Value,
594) {
595    let cx = builder.cx;
596    let OffloadKernelGlobals {
597        offload_sizes,
598        memtransfer_begin,
599        memtransfer_kernel,
600        memtransfer_end,
601        region_id,
602    } = offload_data;
603    let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } =
604        offload_dims;
605
606    let has_dynamic = metadata.iter().any(|m| !#[allow(non_exhaustive_omitted_patterns)] match m.payload_size {
    OffloadSize::Static(_) => true,
    _ => false,
}matches!(m.payload_size, OffloadSize::Static(_)));
607
608    let tgt_decl = offload_globals.launcher_fn;
609    let tgt_target_kernel_ty = offload_globals.launcher_ty;
610
611    let tgt_kernel_decl = offload_globals.kernel_args_ty;
612    let begin_mapper_decl = offload_globals.begin_mapper;
613    let end_mapper_decl = offload_globals.end_mapper;
614    let fn_ty = offload_globals.mapper_fn_ty;
615
616    let num_args = types.len() as u64;
617    let bb = builder.llbb();
618
619    // Step 0)
620    unsafe {
621        llvm::LLVMRustPositionBuilderPastAllocas(&builder.llbuilder, builder.llfn());
622    }
623
624    let ty = cx.type_array(cx.type_ptr(), num_args);
625    // Baseptr are just the input pointer to the kernel, stored in a local alloca
626    let a1 = builder.direct_alloca(ty, Align::EIGHT, ".offload_baseptrs");
627    // Ptrs are the result of a gep into the baseptr, at least for our trivial types.
628    let a2 = builder.direct_alloca(ty, Align::EIGHT, ".offload_ptrs");
629    // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
630    let ty2 = cx.type_array(cx.type_i64(), num_args);
631
632    let a4 = if has_dynamic {
633        let alloc = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes");
634
635        builder.memcpy(
636            alloc,
637            Align::EIGHT,
638            offload_sizes,
639            Align::EIGHT,
640            cx.get_const_i64(8 * args.len() as u64),
641            MemFlags::empty(),
642            None,
643        );
644
645        alloc
646    } else {
647        offload_sizes
648    };
649
650    //%kernel_args = alloca %struct.__tgt_kernel_arguments, align 8
651    let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args");
652
653    // Step 1)
654    unsafe {
655        llvm::LLVMPositionBuilderAtEnd(&builder.llbuilder, bb);
656    }
657
658    // Now we allocate once per function param, a copy to be passed to one of our maps.
659    let mut vals = ::alloc::vec::Vec::new()vec![];
660    let mut geps = ::alloc::vec::Vec::new()vec![];
661    let i32_0 = cx.get_const_i32(0);
662    for &v in args {
663        let ty = cx.val_ty(v);
664        let ty_kind = cx.type_kind(ty);
665        let (base_val, gep_base) = match ty_kind {
666            TypeKind::Pointer => (v, v),
667            TypeKind::Half | TypeKind::Float | TypeKind::Double | TypeKind::Integer => {
668                // FIXME(Sa4dUs): check for `f128` support, latest NVIDIA cards support it
669                let num_bits = scalar_width(cx, ty);
670
671                let bb = builder.llbb();
672                unsafe {
673                    llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, builder.llfn());
674                }
675                let addr = builder.direct_alloca(cx.type_i64(), Align::EIGHT, "addr");
676                unsafe {
677                    llvm::LLVMPositionBuilderAtEnd(builder.llbuilder, bb);
678                }
679
680                let cast = builder.bitcast(v, cx.type_ix(num_bits));
681                let value = builder.zext(cast, cx.type_i64());
682                builder.store(value, addr, Align::EIGHT);
683                (value, addr)
684            }
685            other => ::rustc_middle::util::bug::bug_fmt(format_args!("offload does not support {0:?}",
        other))bug!("offload does not support {other:?}"),
686        };
687
688        let gep = builder.inbounds_gep(cx.type_f32(), gep_base, &[i32_0]);
689
690        vals.push(base_val);
691        geps.push(gep);
692    }
693
694    for i in 0..num_args {
695        let idx = cx.get_const_i32(i);
696        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]);
697        builder.store(vals[i as usize], gep1, Align::EIGHT);
698        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, idx]);
699        builder.store(geps[i as usize], gep2, Align::EIGHT);
700
701        if !#[allow(non_exhaustive_omitted_patterns)] match metadata[i as
                usize].payload_size {
    OffloadSize::Static(_) => true,
    _ => false,
}matches!(metadata[i as usize].payload_size, OffloadSize::Static(_)) {
702            let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, idx]);
703            let size_val = get_runtime_size(builder, args, i as usize, &metadata[i as usize]);
704            builder.store(size_val, gep3, Align::EIGHT);
705        }
706    }
707
708    // For now we have a very simplistic indexing scheme into our
709    // offload_{baseptrs,ptrs,sizes}. We will probably improve this along with our gpu frontend pr.
710    fn get_geps<'ll, 'tcx>(
711        builder: &mut Builder<'_, 'll, 'tcx>,
712        ty: &'ll Type,
713        ty2: &'ll Type,
714        a1: &'ll Value,
715        a2: &'ll Value,
716        a4: &'ll Value,
717        is_dynamic: bool,
718    ) -> [&'ll Value; 3] {
719        let cx = builder.cx;
720        let i32_0 = cx.get_const_i32(0);
721
722        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
723        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
724        let gep3 = if is_dynamic { builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]) } else { a4 };
725        [gep1, gep2, gep3]
726    }
727
728    fn generate_mapper_call<'ll, 'tcx>(
729        builder: &mut Builder<'_, 'll, 'tcx>,
730        geps: [&'ll Value; 3],
731        o_type: &'ll Value,
732        fn_to_call: &'ll Value,
733        fn_ty: &'ll Type,
734        num_args: u64,
735        s_ident_t: &'ll Value,
736    ) {
737        let cx = builder.cx;
738        let nullptr = cx.const_null(cx.type_ptr());
739        let i64_max = cx.get_const_i64(u64::MAX);
740        let num_args = cx.get_const_i32(num_args);
741        let args =
742            ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
        [s_ident_t, i64_max, num_args, geps[0], geps[1], geps[2], o_type,
                nullptr, nullptr]))vec![s_ident_t, i64_max, num_args, geps[0], geps[1], geps[2], o_type, nullptr, nullptr];
743        builder.call(fn_ty, None, None, fn_to_call, &args, None, None);
744    }
745
746    // Step 2)
747    let s_ident_t = offload_globals.ident_t_global;
748    let geps = get_geps(builder, ty, ty2, a1, a2, a4, has_dynamic);
749    generate_mapper_call(
750        builder,
751        geps,
752        memtransfer_begin,
753        begin_mapper_decl,
754        fn_ty,
755        num_args,
756        s_ident_t,
757    );
758    let values = KernelArgsTy::new(
759        &cx,
760        num_args,
761        memtransfer_kernel,
762        geps,
763        workgroup_dims,
764        thread_dims,
765        dyn_cache,
766    );
767
768    // Step 3)
769    // Here we fill the KernelArgsTy, see the documentation above
770    for (i, value) in values.iter().enumerate() {
771        let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]);
772        let name = std::ffi::CString::new(value.1).unwrap();
773        llvm::set_value_name(ptr, &name.as_bytes());
774
775        builder.store(value.2, ptr, value.0);
776    }
777
778    let args = ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
        [s_ident_t, cx.get_const_i64(u64::MAX), num_workgroups,
                threads_per_block, region_id, a5]))vec![
779        s_ident_t,
780        // FIXME(offload) give users a way to select which GPU to use.
781        cx.get_const_i64(u64::MAX), // MAX == -1.
782        num_workgroups,
783        threads_per_block,
784        region_id,
785        a5,
786    ];
787    builder.call(tgt_target_kernel_ty, None, None, tgt_decl, &args, None, None);
788    // %41 = call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 2097152, i32 256, ptr @.kernel_1.region_id, ptr %kernel_args)
789
790    // Step 4)
791    let geps = get_geps(builder, ty, ty2, a1, a2, a4, has_dynamic);
792    generate_mapper_call(
793        builder,
794        geps,
795        memtransfer_end,
796        end_mapper_decl,
797        fn_ty,
798        num_args,
799        s_ident_t,
800    );
801}