rustc_codegen_llvm/builder/gpu_offload.rs

use std::ffi::CString;

use llvm::Linkage::*;
use rustc_abi::Align;
use rustc_codegen_ssa::back::write::CodegenContext;
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;

use crate::builder::SBuilder;
use crate::common::AsCCharPtr;
use crate::llvm::AttributePlace::Function;
use crate::llvm::{self, Linkage, Type, Value};
use crate::{LlvmCodegenBackend, SimpleCx, attributes};

pub(crate) fn handle_gpu_code<'ll>(
    _cgcx: &CodegenContext<LlvmCodegenBackend>,
    cx: &'ll SimpleCx<'_>,
) {
    // The offload memory transfer type for each kernel
    let mut o_types = vec![];
    let mut kernels = vec![];
    let offload_entry_ty = add_tgt_offload_entry(&cx);
    for num in 0..9 {
        let kernel = cx.get_function(&format!("kernel_{num}"));
        if let Some(kernel) = kernel {
            o_types.push(gen_define_handling(&cx, kernel, offload_entry_ty, num));
            kernels.push(kernel);
        }
    }

    gen_call_handling(&cx, &kernels, &o_types);
}

// What is our @1 here? A magic global, used in our data_{begin/update/end}_mapper:
// @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
// @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
    // @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
    let unknown_txt = ";unknown;unknown;0;0;;";
    let c_entry_name = CString::new(unknown_txt).unwrap();
    let c_val = c_entry_name.as_bytes_with_nul();
    let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
    let at_zero = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
    llvm::set_alignment(at_zero, Align::ONE);

    // @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8
    let struct_ident_ty = cx.type_named_struct("struct.ident_t");
    let struct_elems = vec![
        cx.get_const_i32(0),
        cx.get_const_i32(2),
        cx.get_const_i32(0),
        cx.get_const_i32(22),
        at_zero,
    ];
    let struct_elems_ty: Vec<_> = struct_elems.iter().map(|&x| cx.val_ty(x)).collect();
    let initializer = crate::common::named_struct(struct_ident_ty, &struct_elems);
    cx.set_struct_body(struct_ident_ty, &struct_elems_ty, false);
    let at_one = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
    llvm::set_alignment(at_one, Align::EIGHT);
    at_one
}

pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
    let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
    let tptr = cx.type_ptr();
    let ti64 = cx.type_i64();
    let ti32 = cx.type_i32();
    let ti16 = cx.type_i16();
    // For each kernel to run on the gpu, we will later generate one entry of this type.
    // copied from LLVM
    // typedef struct {
    //   uint64_t Reserved;
    //   uint16_t Version;
    //   uint16_t Kind;
    //   uint32_t Flags; Flags associated with the entry (see Target Region Entry Flags)
    //   void *Address; Address of global symbol within device image (function or global)
    //   char *SymbolName;
    //   uint64_t Size; Size of the entry info (0 if it is a function)
    //   uint64_t Data;
    //   void *AuxAddr;
    // } __tgt_offload_entry;
    let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
    cx.set_struct_body(offload_entry_ty, &entry_elements, false);
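    // For reference, with the element list above this struct should lower to LLVM IR roughly
    // like the following (an illustrative sketch, not IR emitted verbatim by this function):
    // %struct.__tgt_offload_entry = type { i64, i16, i16, i32, ptr, ptr, i64, i64, ptr }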
    offload_entry_ty
}

fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) {
    let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
    let tptr = cx.type_ptr();
    let ti64 = cx.type_i64();
    let ti32 = cx.type_i32();
    let tarr = cx.type_array(ti32, 3);

    // Taken from the LLVM APITypes.h declaration:
    //struct KernelArgsTy {
    //  uint32_t Version = 0; // Version of this struct for ABI compatibility.
    //  uint32_t NumArgs = 0; // Number of arguments in each input pointer.
    //  void **ArgBasePtrs =
    //      nullptr;                 // Base pointer of each argument (e.g. a struct).
    //  void **ArgPtrs = nullptr;    // Pointer to the argument data.
    //  int64_t *ArgSizes = nullptr; // Size of the argument data in bytes.
    //  int64_t *ArgTypes = nullptr; // Type of the data (e.g. to / from).
    //  void **ArgNames = nullptr;   // Name of the data for debugging, possibly null.
    //  void **ArgMappers = nullptr; // User-defined mappers, possibly null.
    //  uint64_t Tripcount =
    //      0; // Tripcount for the teams / distribute loop, 0 otherwise.
    //  struct {
    //    uint64_t NoWait : 1; // Was this kernel spawned with a `nowait` clause.
    //    uint64_t IsCUDA : 1; // Was this kernel spawned via CUDA.
    //    uint64_t Unused : 62;
    //  } Flags = {0, 0, 0};
    //  // The number of teams (for x,y,z dimension).
    //  uint32_t NumTeams[3] = {0, 0, 0};
    //  // The number of threads (for x,y,z dimension).
    //  uint32_t ThreadLimit[3] = {0, 0, 0};
    //  uint32_t DynCGroupMem = 0; // Amount of dynamic cgroup memory requested.
    //};
    let kernel_elements =
        vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];

    cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
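    // As a hedged illustration, the named struct defined above should correspond to LLVM IR
    // roughly like (not emitted verbatim here):
    // %struct.__tgt_kernel_arguments = type { i32, i32, ptr, ptr, ptr, ptr, ptr, ptr, i64, i64,
    //                                         [3 x i32], [3 x i32], i32 }
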
    // We don't handle kernels yet, so for now we just add a dummy global
    // to make sure that the __tgt_offload_entry is defined and handled correctly.
    cx.declare_global("my_struct_global2", kernel_arguments_ty);
}

fn gen_tgt_data_mappers<'ll>(
    cx: &'ll SimpleCx<'_>,
) -> (&'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Type) {
    let tptr = cx.type_ptr();
    let ti64 = cx.type_i64();
    let ti32 = cx.type_i32();

    let args = vec![tptr, ti64, ti32, tptr, tptr, tptr, tptr, tptr, tptr];
    let mapper_fn_ty = cx.type_func(&args, cx.type_void());
    let mapper_begin = "__tgt_target_data_begin_mapper";
    let mapper_update = "__tgt_target_data_update_mapper";
    let mapper_end = "__tgt_target_data_end_mapper";
    let begin_mapper_decl = declare_offload_fn(&cx, mapper_begin, mapper_fn_ty);
    let update_mapper_decl = declare_offload_fn(&cx, mapper_update, mapper_fn_ty);
    let end_mapper_decl = declare_offload_fn(&cx, mapper_end, mapper_fn_ty);
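    // As an illustrative sketch, each of these runtime functions should be declared in the
    // module roughly as (signature derived from the `args` vector above):
    // declare void @__tgt_target_data_begin_mapper(ptr, i64, i32, ptr, ptr, ptr, ptr, ptr, ptr)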

    let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
    attributes::apply_to_llfn(begin_mapper_decl, Function, &[nounwind]);
    attributes::apply_to_llfn(update_mapper_decl, Function, &[nounwind]);
    attributes::apply_to_llfn(end_mapper_decl, Function, &[nounwind]);

    (begin_mapper_decl, update_mapper_decl, end_mapper_decl, mapper_fn_ty)
}

fn add_priv_unnamed_arr<'ll>(cx: &SimpleCx<'ll>, name: &str, vals: &[u64]) -> &'ll llvm::Value {
    let ti64 = cx.type_i64();
    let mut size_val = Vec::with_capacity(vals.len());
    for &val in vals {
        size_val.push(cx.get_const_i64(val));
    }
    let initializer = cx.const_array(ti64, &size_val);
    add_unnamed_global(cx, name, initializer, PrivateLinkage)
}

pub(crate) fn add_unnamed_global<'ll>(
    cx: &SimpleCx<'ll>,
    name: &str,
    initializer: &'ll llvm::Value,
    l: Linkage,
) -> &'ll llvm::Value {
    let llglobal = add_global(cx, name, initializer, l);
    llvm::LLVMSetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global);
    llglobal
}

pub(crate) fn add_global<'ll>(
    cx: &SimpleCx<'ll>,
    name: &str,
    initializer: &'ll llvm::Value,
    l: Linkage,
) -> &'ll llvm::Value {
    let c_name = CString::new(name).unwrap();
    let llglobal: &'ll llvm::Value = llvm::add_global(cx.llmod, cx.val_ty(initializer), &c_name);
    llvm::set_global_constant(llglobal, true);
    llvm::set_linkage(llglobal, l);
    llvm::set_initializer(llglobal, initializer);
    llglobal
}

fn gen_define_handling<'ll>(
    cx: &'ll SimpleCx<'_>,
    kernel: &'ll llvm::Value,
    offload_entry_ty: &'ll llvm::Type,
    num: i64,
) -> &'ll llvm::Value {
    let types = cx.func_params_types(cx.get_type_of_global(kernel));
    // It seems like non-pointer values are automatically mapped. So here, we focus on pointer (or
    // reference) types.
    let num_ptr_types = types
        .iter()
        .filter(|&x| matches!(cx.type_kind(x), rustc_codegen_ssa::common::TypeKind::Pointer))
        .count();

    // We no longer know the pointee sizes at this level, so we hardcode a placeholder.
    // A follow-up PR will track these from the frontend, where we still have Rust types.
    // Then we will be able to figure out that e.g. `&[f32; 256]` will result in 4*256 bytes.
    // For now, 1024 bytes is used as a placeholder value.
    add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{num}"), &vec![1024; num_ptr_types]);
    // Here we figure out whether something needs to be copied to the GPU (=1), from the GPU (=2),
    // or both to and from the GPU (=3). Other values shouldn't affect us for now.
    // A non-mutable reference or pointer will be 1, while an array that is not read but fully
    // overwritten will be 2. For now, everything is 3, until we have our frontend set up.
    let o_types =
        add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{num}"), &vec![3; num_ptr_types]);
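    // As a hedged example, for a kernel with a single pointer argument (`num_ptr_types == 1`)
    // and `num == 0`, the two globals above would look roughly like:
    // @.offload_sizes.0 = private unnamed_addr constant [1 x i64] [i64 1024]
    // @.offload_maptypes.0 = private unnamed_addr constant [1 x i64] [i64 3]
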
    // Next, for each kernel function, generate these three entries: a weak constant, the
    // llvm.rodata entry name, and the omp_offloading_entries value.

    let name = format!(".kernel_{num}.region_id");
    let initializer = cx.get_const_i8(0);
    let region_id = add_unnamed_global(&cx, &name, initializer, WeakAnyLinkage);

    let c_entry_name = CString::new(format!("kernel_{num}")).unwrap();
    let c_val = c_entry_name.as_bytes_with_nul();
    let offload_entry_name = format!(".offloading.entry_name.{num}");

    let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
    let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
    llvm::set_alignment(llglobal, Align::ONE);
    llvm::set_section(llglobal, c".llvm.rodata.offloading");

    // Not actively used yet; it will be needed for calling real kernels.
    let name = format!(".offloading.entry.kernel_{num}");

    // See the __tgt_offload_entry documentation above.
    let reserved = cx.get_const_i64(0);
    let version = cx.get_const_i16(1);
    let kind = cx.get_const_i16(1);
    let flags = cx.get_const_i32(0);
    let size = cx.get_const_i64(0);
    let data = cx.get_const_i64(0);
    let aux_addr = cx.const_null(cx.type_ptr());
    let elems = vec![reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr];

    let initializer = crate::common::named_struct(offload_entry_ty, &elems);
    let c_name = CString::new(name).unwrap();
    let llglobal = llvm::add_global(cx.llmod, offload_entry_ty, &c_name);
    llvm::set_global_constant(llglobal, true);
    llvm::set_linkage(llglobal, WeakAnyLinkage);
    llvm::set_initializer(llglobal, initializer);
    llvm::set_alignment(llglobal, Align::ONE);
    let c_section_name = CString::new(".omp_offloading_entries").unwrap();
    llvm::set_section(llglobal, &c_section_name);
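
    // Putting this together, for `num == 0` the resulting entry should look roughly like the
    // following (an illustrative sketch; exact printing depends on the surrounding module):
    // @.offloading.entry.kernel_0 = weak constant %struct.__tgt_offload_entry { i64 0, i16 1,
    //     i16 1, i32 0, ptr @.kernel_0.region_id, ptr @.offloading.entry_name.0, i64 0, i64 0,
    //     ptr null }, section ".omp_offloading_entries", align 1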
    o_types
}

fn declare_offload_fn<'ll>(
    cx: &'ll SimpleCx<'_>,
    name: &str,
    ty: &'ll llvm::Type,
) -> &'ll llvm::Value {
    crate::declare::declare_simple_fn(
        cx,
        name,
        llvm::CallConv::CCallConv,
        llvm::UnnamedAddr::No,
        llvm::Visibility::Default,
        ty,
    )
}

// For each kernel *call*, we now use some of our previously declared globals to move data to and
// from the GPU. We don't have a proper frontend yet, so we assume that every call to a kernel
// function from main is intended to run on the GPU. For now, we only handle the data transfer
// part of it. If two consecutive kernels use the same memory, we still move it to the host and
// back to the GPU. Since users of our frontend don't have to specify data transfers by default,
// this is something we should optimize in the future! We also assume that everything should be
// copied back and forth, but sometimes we could directly zero-allocate on the device and only
// move data back, or, if something is immutable, only copy it to the device and not back.
//
// Current steps:
// 0. Alloca some variables for the following steps.
// 1. Set the insert point before the kernel call.
// 2. Generate all the GEPs and stores, to be used in 3).
// 3. Generate __tgt_target_data_begin calls to move data to the GPU.
//
// unchanged: keep the kernel call. Later, move the kernel to the GPU.
//
// 4. Set the insert point after the kernel call.
// 5. Generate all the GEPs and stores, to be used in 6).
// 6. Generate __tgt_target_data_end calls to move data from the GPU.
fn gen_call_handling<'ll>(
    cx: &'ll SimpleCx<'_>,
    _kernels: &[&'ll llvm::Value],
    o_types: &[&'ll llvm::Value],
) {
    // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
    let tptr = cx.type_ptr();
    let ti32 = cx.type_i32();
    let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr];
    let tgt_bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
    cx.set_struct_body(tgt_bin_desc, &tgt_bin_desc_ty, false);

    gen_tgt_kernel_global(&cx);
    let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);

    let main_fn = cx.get_function("main");
    let Some(main_fn) = main_fn else { return };
    let kernel_name = "kernel_1";
    let call = unsafe {
        llvm::LLVMRustGetFunctionCall(main_fn, kernel_name.as_c_char_ptr(), kernel_name.len())
    };
    let Some(kernel_call) = call else {
        return;
    };
    let kernel_call_bb = unsafe { llvm::LLVMGetInstructionParent(kernel_call) };
    let called = unsafe { llvm::LLVMGetCalledValue(kernel_call).unwrap() };
    let mut builder = SBuilder::build(cx, kernel_call_bb);

    let types = cx.func_params_types(cx.get_type_of_global(called));
    let num_args = types.len() as u64;

    // Step 0)
    // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
    // %6 = alloca %struct.__tgt_bin_desc, align 8
    unsafe { llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn) };

    let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc");

    let ty = cx.type_array(cx.type_ptr(), num_args);
    // Baseptrs are just the input pointers to the kernel, stored in a local alloca.
    let a1 = builder.direct_alloca(ty, Align::EIGHT, ".offload_baseptrs");
    // Ptrs are the result of a GEP into the baseptr, at least for our trivial types.
    let a2 = builder.direct_alloca(ty, Align::EIGHT, ".offload_ptrs");
    // These represent the sizes in bytes, e.g. the entry for `&[f64; 16]` will be 8*16.
    let ty2 = cx.type_array(cx.type_i64(), num_args);
    let a4 = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes");
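    // As a hedged sketch, for a single-argument kernel (`num_args == 1`) these allocas should
    // lower to something like:
    // %EmptyDesc = alloca %struct.__tgt_bin_desc, align 8
    // %.offload_baseptrs = alloca [1 x ptr], align 8
    // %.offload_ptrs = alloca [1 x ptr], align 8
    // %.offload_sizes = alloca [1 x i64], align 8
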
    // Now we allocate once per function param, a copy to be passed to one of our maps.
    let mut vals = vec![];
    let mut geps = vec![];
    let i32_0 = cx.get_const_i32(0);
    for (index, in_ty) in types.iter().enumerate() {
        // Get the function arg, store it into the alloca, and read it back.
        let p = llvm::get_param(called, index as u32);
        let name = llvm::get_value_name(p);
        let name = str::from_utf8(&name).unwrap();
        let arg_name = format!("{name}.addr");
        let alloca = builder.direct_alloca(in_ty, Align::EIGHT, &arg_name);

        builder.store(p, alloca, Align::EIGHT);
        let val = builder.load(in_ty, alloca, Align::EIGHT);
        let gep = builder.inbounds_gep(cx.type_f32(), val, &[i32_0]);
        vals.push(val);
        geps.push(gep);
    }
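
    // For illustration, assuming a kernel argument `%x` of pointer type, one iteration of the
    // loop above should produce IR roughly like this (value names are hypothetical):
    // %x.addr = alloca ptr, align 8
    // store ptr %x, ptr %x.addr, align 8
    // %0 = load ptr, ptr %x.addr, align 8
    // %1 = getelementptr inbounds float, ptr %0, i32 0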

    // Step 1)
    unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
    builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);
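    // The memset above zero-initializes the 32-byte __tgt_bin_desc struct; as a rough sketch it
    // should lower to something like:
    // call void @llvm.memset.p0.i64(ptr align 8 %EmptyDesc, i8 0, i64 32, i1 false)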

    let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void());
    let register_lib_decl = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty);
    let unregister_lib_decl = declare_offload_fn(&cx, "__tgt_unregister_lib", mapper_fn_ty);
    let init_ty = cx.type_func(&[], cx.type_void());
    let init_rtls_decl = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);

    // call void @__tgt_register_lib(ptr noundef %6)
    builder.call(mapper_fn_ty, register_lib_decl, &[tgt_bin_desc_alloca], None);
    // call void @__tgt_init_all_rtls()
    builder.call(init_ty, init_rtls_decl, &[], None);

    for i in 0..num_args {
        let idx = cx.get_const_i32(i);
        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]);
        builder.store(vals[i as usize], gep1, Align::EIGHT);
        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, idx]);
        builder.store(geps[i as usize], gep2, Align::EIGHT);
        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, idx]);
        // As mentioned above, we don't use Rust type information yet. So for now we will just
        // assume that we have 1024 bytes, i.e. 256 f32 values.
        // FIXME(offload): write an offload frontend and handle arbitrary types.
        builder.store(cx.get_const_i64(1024), gep3, Align::EIGHT);
    }
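
    // As a hedged illustration, for a single pointer argument one iteration of this loop should
    // emit IR along the lines of:
    // %2 = getelementptr inbounds [1 x ptr], ptr %.offload_baseptrs, i32 0, i32 0
    // store ptr %0, ptr %2, align 8
    // %3 = getelementptr inbounds [1 x ptr], ptr %.offload_ptrs, i32 0, i32 0
    // store ptr %1, ptr %3, align 8
    // %4 = getelementptr inbounds [1 x i64], ptr %.offload_sizes, i32 0, i32 0
    // store i64 1024, ptr %4, align 8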

    // For now we have a very simplistic indexing scheme into our
    // offload_{baseptrs,ptrs,sizes}. We will probably improve this along with our GPU frontend PR.
    fn get_geps<'a, 'll>(
        builder: &mut SBuilder<'a, 'll>,
        cx: &'ll SimpleCx<'ll>,
        ty: &'ll Type,
        ty2: &'ll Type,
        a1: &'ll Value,
        a2: &'ll Value,
        a4: &'ll Value,
    ) -> (&'ll Value, &'ll Value, &'ll Value) {
        let i32_0 = cx.get_const_i32(0);

        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
        (gep1, gep2, gep3)
    }

    fn generate_mapper_call<'a, 'll>(
        builder: &mut SBuilder<'a, 'll>,
        cx: &'ll SimpleCx<'ll>,
        geps: (&'ll Value, &'ll Value, &'ll Value),
        o_type: &'ll Value,
        fn_to_call: &'ll Value,
        fn_ty: &'ll Type,
        num_args: u64,
        s_ident_t: &'ll Value,
    ) {
        let nullptr = cx.const_null(cx.type_ptr());
        let i64_max = cx.get_const_i64(u64::MAX);
        let num_args = cx.get_const_i32(num_args);
        let args =
            vec![s_ident_t, i64_max, num_args, geps.0, geps.1, geps.2, o_type, nullptr, nullptr];
        builder.call(fn_ty, fn_to_call, &args, None);
    }

    // Step 2)
    let s_ident_t = generate_at_one(&cx);
    let o = o_types[0];
    let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
    generate_mapper_call(&mut builder, &cx, geps, o, begin_mapper_decl, fn_ty, num_args, s_ident_t);

    // Step 3)
    // Here we will add code for the actual kernel launches in a follow-up PR.
    // FIXME(offload): launch kernels

    // Step 4)
    unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };

    let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
    generate_mapper_call(&mut builder, &cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t);

    builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);

    // With this we have generated the begin and end mappers shown below. The update mapper could
    // be generated in the same way.
    // call void @__tgt_target_data_begin_mapper(ptr @1, i64 -1, i32 3, ptr %27, ptr %28, ptr %29, ptr @.offload_maptypes, ptr null, ptr null)
    // call void @__tgt_target_data_update_mapper(ptr @1, i64 -1, i32 2, ptr %46, ptr %47, ptr %48, ptr @.offload_maptypes.1, ptr null, ptr null)
    // call void @__tgt_target_data_end_mapper(ptr @1, i64 -1, i32 3, ptr %49, ptr %50, ptr %51, ptr @.offload_maptypes, ptr null, ptr null)
}