rustc_codegen_llvm/builder/
gpu_offload.rs1use std::ffi::CString;
2
3use llvm::Linkage::*;
4use rustc_abi::Align;
5use rustc_codegen_ssa::back::write::CodegenContext;
6use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
7
8use crate::builder::SBuilder;
9use crate::common::AsCCharPtr;
10use crate::llvm::AttributePlace::Function;
11use crate::llvm::{self, Linkage, Type, Value};
12use crate::{LlvmCodegenBackend, SimpleCx, attributes};
13
14pub(crate) fn handle_gpu_code<'ll>(
15 _cgcx: &CodegenContext<LlvmCodegenBackend>,
16 cx: &'ll SimpleCx<'_>,
17) {
18 let mut o_types = vec![];
20 let mut kernels = vec![];
21 let offload_entry_ty = add_tgt_offload_entry(&cx);
22 for num in 0..9 {
23 let kernel = cx.get_function(&format!("kernel_{num}"));
24 if let Some(kernel) = kernel {
25 o_types.push(gen_define_handling(&cx, kernel, offload_entry_ty, num));
26 kernels.push(kernel);
27 }
28 }
29
30 gen_call_handling(&cx, &kernels, &o_types);
31}
32
33fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
37 let unknown_txt = ";unknown;unknown;0;0;;";
39 let c_entry_name = CString::new(unknown_txt).unwrap();
40 let c_val = c_entry_name.as_bytes_with_nul();
41 let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
42 let at_zero = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
43 llvm::set_alignment(at_zero, Align::ONE);
44
45 let struct_ident_ty = cx.type_named_struct("struct.ident_t");
47 let struct_elems = vec![
48 cx.get_const_i32(0),
49 cx.get_const_i32(2),
50 cx.get_const_i32(0),
51 cx.get_const_i32(22),
52 at_zero,
53 ];
54 let struct_elems_ty: Vec<_> = struct_elems.iter().map(|&x| cx.val_ty(x)).collect();
55 let initializer = crate::common::named_struct(struct_ident_ty, &struct_elems);
56 cx.set_struct_body(struct_ident_ty, &struct_elems_ty, false);
57 let at_one = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
58 llvm::set_alignment(at_one, Align::EIGHT);
59 at_one
60}
61
62pub(crate) fn add_tgt_offload_entry<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
63 let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
64 let tptr = cx.type_ptr();
65 let ti64 = cx.type_i64();
66 let ti32 = cx.type_i32();
67 let ti16 = cx.type_i16();
68 let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
82 cx.set_struct_body(offload_entry_ty, &entry_elements, false);
83 offload_entry_ty
84}
85
86fn gen_tgt_kernel_global<'ll>(cx: &'ll SimpleCx<'_>) {
87 let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
88 let tptr = cx.type_ptr();
89 let ti64 = cx.type_i64();
90 let ti32 = cx.type_i32();
91 let tarr = cx.type_array(ti32, 3);
92
93 let kernel_elements =
118 vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
119
120 cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
121 cx.declare_global("my_struct_global2", kernel_arguments_ty);
124}
125
126fn gen_tgt_data_mappers<'ll>(
127 cx: &'ll SimpleCx<'_>,
128) -> (&'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Type) {
129 let tptr = cx.type_ptr();
130 let ti64 = cx.type_i64();
131 let ti32 = cx.type_i32();
132
133 let args = vec![tptr, ti64, ti32, tptr, tptr, tptr, tptr, tptr, tptr];
134 let mapper_fn_ty = cx.type_func(&args, cx.type_void());
135 let mapper_begin = "__tgt_target_data_begin_mapper";
136 let mapper_update = "__tgt_target_data_update_mapper";
137 let mapper_end = "__tgt_target_data_end_mapper";
138 let begin_mapper_decl = declare_offload_fn(&cx, mapper_begin, mapper_fn_ty);
139 let update_mapper_decl = declare_offload_fn(&cx, mapper_update, mapper_fn_ty);
140 let end_mapper_decl = declare_offload_fn(&cx, mapper_end, mapper_fn_ty);
141
142 let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
143 attributes::apply_to_llfn(begin_mapper_decl, Function, &[nounwind]);
144 attributes::apply_to_llfn(update_mapper_decl, Function, &[nounwind]);
145 attributes::apply_to_llfn(end_mapper_decl, Function, &[nounwind]);
146
147 (begin_mapper_decl, update_mapper_decl, end_mapper_decl, mapper_fn_ty)
148}
149
150fn add_priv_unnamed_arr<'ll>(cx: &SimpleCx<'ll>, name: &str, vals: &[u64]) -> &'ll llvm::Value {
151 let ti64 = cx.type_i64();
152 let mut size_val = Vec::with_capacity(vals.len());
153 for &val in vals {
154 size_val.push(cx.get_const_i64(val));
155 }
156 let initializer = cx.const_array(ti64, &size_val);
157 add_unnamed_global(cx, name, initializer, PrivateLinkage)
158}
159
160pub(crate) fn add_unnamed_global<'ll>(
161 cx: &SimpleCx<'ll>,
162 name: &str,
163 initializer: &'ll llvm::Value,
164 l: Linkage,
165) -> &'ll llvm::Value {
166 let llglobal = add_global(cx, name, initializer, l);
167 llvm::LLVMSetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global);
168 llglobal
169}
170
171pub(crate) fn add_global<'ll>(
172 cx: &SimpleCx<'ll>,
173 name: &str,
174 initializer: &'ll llvm::Value,
175 l: Linkage,
176) -> &'ll llvm::Value {
177 let c_name = CString::new(name).unwrap();
178 let llglobal: &'ll llvm::Value = llvm::add_global(cx.llmod, cx.val_ty(initializer), &c_name);
179 llvm::set_global_constant(llglobal, true);
180 llvm::set_linkage(llglobal, l);
181 llvm::set_initializer(llglobal, initializer);
182 llglobal
183}
184
185fn gen_define_handling<'ll>(
186 cx: &'ll SimpleCx<'_>,
187 kernel: &'ll llvm::Value,
188 offload_entry_ty: &'ll llvm::Type,
189 num: i64,
190) -> &'ll llvm::Value {
191 let types = cx.func_params_types(cx.get_type_of_global(kernel));
192 let num_ptr_types = types
195 .iter()
196 .filter(|&x| matches!(cx.type_kind(x), rustc_codegen_ssa::common::TypeKind::Pointer))
197 .count();
198
199 add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{num}"), &vec![1024; num_ptr_types]);
204 let o_types =
209 add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{num}"), &vec![3; num_ptr_types]);
210 let name = format!(".kernel_{num}.region_id");
214 let initializer = cx.get_const_i8(0);
215 let region_id = add_unnamed_global(&cx, &name, initializer, WeakAnyLinkage);
216
217 let c_entry_name = CString::new(format!("kernel_{num}")).unwrap();
218 let c_val = c_entry_name.as_bytes_with_nul();
219 let offload_entry_name = format!(".offloading.entry_name.{num}");
220
221 let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
222 let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
223 llvm::set_alignment(llglobal, Align::ONE);
224 llvm::set_section(llglobal, c".llvm.rodata.offloading");
225
226 let name = format!(".offloading.entry.kernel_{num}");
228
229 let reserved = cx.get_const_i64(0);
231 let version = cx.get_const_i16(1);
232 let kind = cx.get_const_i16(1);
233 let flags = cx.get_const_i32(0);
234 let size = cx.get_const_i64(0);
235 let data = cx.get_const_i64(0);
236 let aux_addr = cx.const_null(cx.type_ptr());
237 let elems = vec![reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr];
238
239 let initializer = crate::common::named_struct(offload_entry_ty, &elems);
240 let c_name = CString::new(name).unwrap();
241 let llglobal = llvm::add_global(cx.llmod, offload_entry_ty, &c_name);
242 llvm::set_global_constant(llglobal, true);
243 llvm::set_linkage(llglobal, WeakAnyLinkage);
244 llvm::set_initializer(llglobal, initializer);
245 llvm::set_alignment(llglobal, Align::ONE);
246 let c_section_name = CString::new(".omp_offloading_entries").unwrap();
247 llvm::set_section(llglobal, &c_section_name);
248 o_types
249}
250
251fn declare_offload_fn<'ll>(
252 cx: &'ll SimpleCx<'_>,
253 name: &str,
254 ty: &'ll llvm::Type,
255) -> &'ll llvm::Value {
256 crate::declare::declare_simple_fn(
257 cx,
258 name,
259 llvm::CallConv::CCallConv,
260 llvm::UnnamedAddr::No,
261 llvm::Visibility::Default,
262 ty,
263 )
264}
265
266fn gen_call_handling<'ll>(
287 cx: &'ll SimpleCx<'_>,
288 _kernels: &[&'ll llvm::Value],
289 o_types: &[&'ll llvm::Value],
290) {
291 let tptr = cx.type_ptr();
293 let ti32 = cx.type_i32();
294 let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr];
295 let tgt_bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
296 cx.set_struct_body(tgt_bin_desc, &tgt_bin_desc_ty, false);
297
298 gen_tgt_kernel_global(&cx);
299 let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);
300
301 let main_fn = cx.get_function("main");
302 let Some(main_fn) = main_fn else { return };
303 let kernel_name = "kernel_1";
304 let call = unsafe {
305 llvm::LLVMRustGetFunctionCall(main_fn, kernel_name.as_c_char_ptr(), kernel_name.len())
306 };
307 let Some(kernel_call) = call else {
308 return;
309 };
310 let kernel_call_bb = unsafe { llvm::LLVMGetInstructionParent(kernel_call) };
311 let called = unsafe { llvm::LLVMGetCalledValue(kernel_call).unwrap() };
312 let mut builder = SBuilder::build(cx, kernel_call_bb);
313
314 let types = cx.func_params_types(cx.get_type_of_global(called));
315 let num_args = types.len() as u64;
316
317 unsafe { llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn) };
321
322 let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc");
323
324 let ty = cx.type_array(cx.type_ptr(), num_args);
325 let a1 = builder.direct_alloca(ty, Align::EIGHT, ".offload_baseptrs");
327 let a2 = builder.direct_alloca(ty, Align::EIGHT, ".offload_ptrs");
329 let ty2 = cx.type_array(cx.type_i64(), num_args);
331 let a4 = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes");
332 let mut vals = vec![];
334 let mut geps = vec![];
335 let i32_0 = cx.get_const_i32(0);
336 for (index, in_ty) in types.iter().enumerate() {
337 let p = llvm::get_param(called, index as u32);
339 let name = llvm::get_value_name(p);
340 let name = str::from_utf8(&name).unwrap();
341 let arg_name = format!("{name}.addr");
342 let alloca = builder.direct_alloca(in_ty, Align::EIGHT, &arg_name);
343
344 builder.store(p, alloca, Align::EIGHT);
345 let val = builder.load(in_ty, alloca, Align::EIGHT);
346 let gep = builder.inbounds_gep(cx.type_f32(), val, &[i32_0]);
347 vals.push(val);
348 geps.push(gep);
349 }
350
351 unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
353 builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);
354
355 let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void());
356 let register_lib_decl = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty);
357 let unregister_lib_decl = declare_offload_fn(&cx, "__tgt_unregister_lib", mapper_fn_ty);
358 let init_ty = cx.type_func(&[], cx.type_void());
359 let init_rtls_decl = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);
360
361 builder.call(mapper_fn_ty, register_lib_decl, &[tgt_bin_desc_alloca], None);
363 builder.call(init_ty, init_rtls_decl, &[], None);
365
366 for i in 0..num_args {
367 let idx = cx.get_const_i32(i);
368 let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]);
369 builder.store(vals[i as usize], gep1, Align::EIGHT);
370 let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, idx]);
371 builder.store(geps[i as usize], gep2, Align::EIGHT);
372 let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, idx]);
373 builder.store(cx.get_const_i64(1024), gep3, Align::EIGHT);
377 }
378
379 fn get_geps<'a, 'll>(
382 builder: &mut SBuilder<'a, 'll>,
383 cx: &'ll SimpleCx<'ll>,
384 ty: &'ll Type,
385 ty2: &'ll Type,
386 a1: &'ll Value,
387 a2: &'ll Value,
388 a4: &'ll Value,
389 ) -> (&'ll Value, &'ll Value, &'ll Value) {
390 let i32_0 = cx.get_const_i32(0);
391
392 let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
393 let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
394 let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
395 (gep1, gep2, gep3)
396 }
397
398 fn generate_mapper_call<'a, 'll>(
399 builder: &mut SBuilder<'a, 'll>,
400 cx: &'ll SimpleCx<'ll>,
401 geps: (&'ll Value, &'ll Value, &'ll Value),
402 o_type: &'ll Value,
403 fn_to_call: &'ll Value,
404 fn_ty: &'ll Type,
405 num_args: u64,
406 s_ident_t: &'ll Value,
407 ) {
408 let nullptr = cx.const_null(cx.type_ptr());
409 let i64_max = cx.get_const_i64(u64::MAX);
410 let num_args = cx.get_const_i32(num_args);
411 let args =
412 vec![s_ident_t, i64_max, num_args, geps.0, geps.1, geps.2, o_type, nullptr, nullptr];
413 builder.call(fn_ty, fn_to_call, &args, None);
414 }
415
416 let s_ident_t = generate_at_one(&cx);
418 let o = o_types[0];
419 let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
420 generate_mapper_call(&mut builder, &cx, geps, o, begin_mapper_decl, fn_ty, num_args, s_ident_t);
421
422 unsafe { llvm::LLVMRustPositionAfter(builder.llbuilder, kernel_call) };
428
429 let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
430 generate_mapper_call(&mut builder, &cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t);
431
432 builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);
433
434 }