1use std::ffi::CString;
2
3use llvm::Linkage::*;
4use rustc_abi::Align;
5use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
6use rustc_middle::ty::offload_meta::OffloadMetadata;
7
8use crate::builder::Builder;
9use crate::common::CodegenCx;
10use crate::llvm::AttributePlace::Function;
11use crate::llvm::{self, Linkage, Type, Value};
12use crate::{SimpleCx, attributes};
13
14pub(crate) struct OffloadGlobals<'ll> {
16 pub launcher_fn: &'ll llvm::Value,
17 pub launcher_ty: &'ll llvm::Type,
18
19 pub bin_desc: &'ll llvm::Type,
20
21 pub kernel_args_ty: &'ll llvm::Type,
22
23 pub offload_entry_ty: &'ll llvm::Type,
24
25 pub begin_mapper: &'ll llvm::Value,
26 pub end_mapper: &'ll llvm::Value,
27 pub mapper_fn_ty: &'ll llvm::Type,
28
29 pub ident_t_global: &'ll llvm::Value,
30
31 pub register_lib: &'ll llvm::Value,
32 pub unregister_lib: &'ll llvm::Value,
33 pub init_rtls: &'ll llvm::Value,
34}
35
36impl<'ll> OffloadGlobals<'ll> {
37 pub(crate) fn declare(cx: &CodegenCx<'ll, '_>) -> Self {
38 let (launcher_fn, launcher_ty) = generate_launcher(cx);
39 let kernel_args_ty = KernelArgsTy::new_decl(cx);
40 let offload_entry_ty = TgtOffloadEntry::new_decl(cx);
41 let (begin_mapper, _, end_mapper, mapper_fn_ty) = gen_tgt_data_mappers(cx);
42 let ident_t_global = generate_at_one(cx);
43
44 let tptr = cx.type_ptr();
45 let ti32 = cx.type_i32();
46 let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr];
47 let bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
48 cx.set_struct_body(bin_desc, &tgt_bin_desc_ty, false);
49
50 let register_lib = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty);
51 let unregister_lib = declare_offload_fn(&cx, "__tgt_unregister_lib", mapper_fn_ty);
52 let init_ty = cx.type_func(&[], cx.type_void());
53 let init_rtls = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);
54
55 OffloadGlobals {
56 launcher_fn,
57 launcher_ty,
58 bin_desc,
59 kernel_args_ty,
60 offload_entry_ty,
61 begin_mapper,
62 end_mapper,
63 mapper_fn_ty,
64 ident_t_global,
65 register_lib,
66 unregister_lib,
67 init_rtls,
68 }
69 }
70}
71
72fn generate_launcher<'ll>(cx: &CodegenCx<'ll, '_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
75 let tptr = cx.type_ptr();
76 let ti64 = cx.type_i64();
77 let ti32 = cx.type_i32();
78 let args = vec![tptr, ti64, ti32, ti32, tptr, tptr];
79 let tgt_fn_ty = cx.type_func(&args, ti32);
80 let name = "__tgt_target_kernel";
81 let tgt_decl = declare_offload_fn(&cx, name, tgt_fn_ty);
82 let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
83 attributes::apply_to_llfn(tgt_decl, Function, &[nounwind]);
84 (tgt_decl, tgt_fn_ty)
85}
86
87pub(crate) fn generate_at_one<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll llvm::Value {
93 let unknown_txt = ";unknown;unknown;0;0;;";
94 let c_entry_name = CString::new(unknown_txt).unwrap();
95 let c_val = c_entry_name.as_bytes_with_nul();
96 let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
97 let at_zero = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
98 llvm::set_alignment(at_zero, Align::ONE);
99
100 let struct_ident_ty = cx.type_named_struct("struct.ident_t");
102 let struct_elems = vec![
103 cx.get_const_i32(0),
104 cx.get_const_i32(2),
105 cx.get_const_i32(0),
106 cx.get_const_i32(22),
107 at_zero,
108 ];
109 let struct_elems_ty: Vec<_> = struct_elems.iter().map(|&x| cx.val_ty(x)).collect();
110 let initializer = crate::common::named_struct(struct_ident_ty, &struct_elems);
111 cx.set_struct_body(struct_ident_ty, &struct_elems_ty, false);
112 let at_one = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
113 llvm::set_alignment(at_one, Align::EIGHT);
114 at_one
115}
116
/// Field-less type grouping the helpers that build `struct.__tgt_offload_entry`.
///
/// `TgtOffloadEntry::new_decl` gives that type the body
/// `{ i64, i16, i16, i32, ptr, ptr, i64, i64, ptr }`. `TgtOffloadEntry::new` fills it, in
/// order, as: reserved, version, kind, flags, region id, entry-name string, size, data,
/// aux address.
pub(crate) struct TgtOffloadEntry {}
128
129impl TgtOffloadEntry {
130 pub(crate) fn new_decl<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll llvm::Type {
131 let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
132 let tptr = cx.type_ptr();
133 let ti64 = cx.type_i64();
134 let ti32 = cx.type_i32();
135 let ti16 = cx.type_i16();
136 let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
139 cx.set_struct_body(offload_entry_ty, &entry_elements, false);
140 offload_entry_ty
141 }
142
143 fn new<'ll>(
144 cx: &CodegenCx<'ll, '_>,
145 region_id: &'ll Value,
146 llglobal: &'ll Value,
147 ) -> [&'ll Value; 9] {
148 let reserved = cx.get_const_i64(0);
149 let version = cx.get_const_i16(1);
150 let kind = cx.get_const_i16(1);
151 let flags = cx.get_const_i32(0);
152 let size = cx.get_const_i64(0);
153 let data = cx.get_const_i64(0);
154 let aux_addr = cx.const_null(cx.type_ptr());
155 [reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr]
156 }
157}
158
/// Field-less type grouping the helpers that build `struct.__tgt_kernel_arguments`.
///
/// That struct is the argument block whose alloca is passed as the last argument of
/// `__tgt_target_kernel`.
struct KernelArgsTy {}
183
184impl KernelArgsTy {
185 const OFFLOAD_VERSION: u64 = 3;
186 const FLAGS: u64 = 0;
187 const TRIPCOUNT: u64 = 0;
188 fn new_decl<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll Type {
189 let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
190 let tptr = cx.type_ptr();
191 let ti64 = cx.type_i64();
192 let ti32 = cx.type_i32();
193 let tarr = cx.type_array(ti32, 3);
194
195 let kernel_elements =
196 vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
197
198 cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
199 kernel_arguments_ty
200 }
201
202 fn new<'ll, 'tcx>(
203 cx: &CodegenCx<'ll, 'tcx>,
204 num_args: u64,
205 memtransfer_types: &'ll Value,
206 geps: [&'ll Value; 3],
207 ) -> [(Align, &'ll Value); 13] {
208 let four = Align::from_bytes(4).expect("4 Byte alignment should work");
209 let eight = Align::EIGHT;
210
211 let ti32 = cx.type_i32();
212 let ci32_0 = cx.get_const_i32(0);
213 [
214 (four, cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION)),
215 (four, cx.get_const_i32(num_args)),
216 (eight, geps[0]),
217 (eight, geps[1]),
218 (eight, geps[2]),
219 (eight, memtransfer_types),
220 (eight, cx.const_null(cx.type_ptr())), (eight, cx.const_null(cx.type_ptr())), (eight, cx.get_const_i64(KernelArgsTy::TRIPCOUNT)),
224 (eight, cx.get_const_i64(KernelArgsTy::FLAGS)),
225 (four, cx.const_array(ti32, &[cx.get_const_i32(2097152), ci32_0, ci32_0])),
226 (four, cx.const_array(ti32, &[cx.get_const_i32(256), ci32_0, ci32_0])),
227 (four, cx.get_const_i32(0)),
228 ]
229 }
230}
231
232#[derive(Copy, Clone)]
234pub(crate) struct OffloadKernelGlobals<'ll> {
235 pub offload_sizes: &'ll llvm::Value,
236 pub memtransfer_types: &'ll llvm::Value,
237 pub region_id: &'ll llvm::Value,
238 pub offload_entry: &'ll llvm::Value,
239}
240
241fn gen_tgt_data_mappers<'ll>(
242 cx: &CodegenCx<'ll, '_>,
243) -> (&'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Type) {
244 let tptr = cx.type_ptr();
245 let ti64 = cx.type_i64();
246 let ti32 = cx.type_i32();
247
248 let args = vec![tptr, ti64, ti32, tptr, tptr, tptr, tptr, tptr, tptr];
249 let mapper_fn_ty = cx.type_func(&args, cx.type_void());
250 let mapper_begin = "__tgt_target_data_begin_mapper";
251 let mapper_update = "__tgt_target_data_update_mapper";
252 let mapper_end = "__tgt_target_data_end_mapper";
253 let begin_mapper_decl = declare_offload_fn(&cx, mapper_begin, mapper_fn_ty);
254 let update_mapper_decl = declare_offload_fn(&cx, mapper_update, mapper_fn_ty);
255 let end_mapper_decl = declare_offload_fn(&cx, mapper_end, mapper_fn_ty);
256
257 let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
258 attributes::apply_to_llfn(begin_mapper_decl, Function, &[nounwind]);
259 attributes::apply_to_llfn(update_mapper_decl, Function, &[nounwind]);
260 attributes::apply_to_llfn(end_mapper_decl, Function, &[nounwind]);
261
262 (begin_mapper_decl, update_mapper_decl, end_mapper_decl, mapper_fn_ty)
263}
264
265fn add_priv_unnamed_arr<'ll>(cx: &SimpleCx<'ll>, name: &str, vals: &[u64]) -> &'ll llvm::Value {
266 let ti64 = cx.type_i64();
267 let mut size_val = Vec::with_capacity(vals.len());
268 for &val in vals {
269 size_val.push(cx.get_const_i64(val));
270 }
271 let initializer = cx.const_array(ti64, &size_val);
272 add_unnamed_global(cx, name, initializer, PrivateLinkage)
273}
274
275pub(crate) fn add_unnamed_global<'ll>(
276 cx: &SimpleCx<'ll>,
277 name: &str,
278 initializer: &'ll llvm::Value,
279 l: Linkage,
280) -> &'ll llvm::Value {
281 let llglobal = add_global(cx, name, initializer, l);
282 llvm::LLVMSetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global);
283 llglobal
284}
285
286pub(crate) fn add_global<'ll>(
287 cx: &SimpleCx<'ll>,
288 name: &str,
289 initializer: &'ll llvm::Value,
290 l: Linkage,
291) -> &'ll llvm::Value {
292 let c_name = CString::new(name).unwrap();
293 let llglobal: &'ll llvm::Value = llvm::add_global(cx.llmod, cx.val_ty(initializer), &c_name);
294 llvm::set_global_constant(llglobal, true);
295 llvm::set_linkage(llglobal, l);
296 llvm::set_initializer(llglobal, initializer);
297 llglobal
298}
299
300pub(crate) fn gen_define_handling<'ll>(
304 cx: &CodegenCx<'ll, '_>,
305 metadata: &[OffloadMetadata],
306 types: &[&'ll Type],
307 symbol: String,
308 offload_globals: &OffloadGlobals<'ll>,
309) -> OffloadKernelGlobals<'ll> {
310 if let Some(entry) = cx.offload_kernel_cache.borrow().get(&symbol) {
311 return *entry;
312 }
313
314 let offload_entry_ty = offload_globals.offload_entry_ty;
315
316 let ptr_meta = types.iter().zip(metadata).filter_map(|(&x, meta)| match cx.type_kind(x) {
319 rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta),
320 _ => None,
321 });
322
323 let (ptr_sizes, ptr_transfer): (Vec<_>, Vec<_>) =
325 ptr_meta.map(|m| (m.payload_size, m.mode.bits() | 0x20)).unzip();
326
327 let offload_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &ptr_sizes);
328 let memtransfer_types =
334 add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}"), &ptr_transfer);
335
336 let name = format!(".{symbol}.region_id");
340 let initializer = cx.get_const_i8(0);
341 let region_id = add_global(&cx, &name, initializer, WeakAnyLinkage);
342
343 let c_entry_name = CString::new(symbol.clone()).unwrap();
344 let c_val = c_entry_name.as_bytes_with_nul();
345 let offload_entry_name = format!(".offloading.entry_name.{symbol}");
346
347 let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
348 let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
349 llvm::set_alignment(llglobal, Align::ONE);
350 llvm::set_section(llglobal, c".llvm.rodata.offloading");
351
352 let name = format!(".offloading.entry.{symbol}");
353
354 let elems = TgtOffloadEntry::new(&cx, region_id, llglobal);
356
357 let initializer = crate::common::named_struct(offload_entry_ty, &elems);
358 let c_name = CString::new(name).unwrap();
359 let offload_entry = llvm::add_global(cx.llmod, offload_entry_ty, &c_name);
360 llvm::set_global_constant(offload_entry, true);
361 llvm::set_linkage(offload_entry, WeakAnyLinkage);
362 llvm::set_initializer(offload_entry, initializer);
363 llvm::set_alignment(offload_entry, Align::EIGHT);
364 let c_section_name = CString::new("llvm_offload_entries").unwrap();
365 llvm::set_section(offload_entry, &c_section_name);
366
367 let result =
368 OffloadKernelGlobals { offload_sizes, memtransfer_types, region_id, offload_entry };
369
370 cx.offload_kernel_cache.borrow_mut().insert(symbol, result);
371
372 result
373}
374
375fn declare_offload_fn<'ll>(
376 cx: &CodegenCx<'ll, '_>,
377 name: &str,
378 ty: &'ll llvm::Type,
379) -> &'ll llvm::Value {
380 crate::declare::declare_simple_fn(
381 cx,
382 name,
383 llvm::CallConv::CCallConv,
384 llvm::UnnamedAddr::No,
385 llvm::Visibility::Default,
386 ty,
387 )
388}
389
/// Emits, at the builder's current insertion point, the host-side call sequence that runs
/// one offloaded kernel.
///
/// Steps:
/// 0. Create allocas, positioned past the function's existing allocas: an empty
///    `__tgt_bin_desc`, `.offload_baseptrs` / `.offload_ptrs` (`[N x ptr]`),
///    `.offload_sizes` (`[N x i64]`) and `kernel_args` (`struct.__tgt_kernel_arguments`),
///    where `N = types.len()`.
/// 1. Zero the descriptor and call `__tgt_register_lib` and `__tgt_init_all_rtls`. Then
///    fill the three per-argument arrays from `args` and `metadata`.
/// 2. Call `__tgt_target_data_begin_mapper`.
/// 3. Fill `kernel_args` and call `__tgt_target_kernel`.
/// 4. Call `__tgt_target_data_end_mapper`, then `__tgt_unregister_lib`.
///
/// # Panics
/// Panics if `args` or `metadata` has fewer elements than `types`, because both are indexed
/// for every `i < types.len()`.
///
/// NOTE(review): `memtransfer_types` comes from `gen_define_handling`, which only emits
/// entries for pointer-typed arguments, but it is passed to the runtime with
/// `num_args = types.len()`. Confirm that every argument is a pointer.
pub(crate) fn gen_call_handling<'ll, 'tcx>(
    builder: &mut Builder<'_, 'll, 'tcx>,
    offload_data: &OffloadKernelGlobals<'ll>,
    args: &[&'ll Value],
    types: &[&Type],
    metadata: &[OffloadMetadata],
    offload_globals: &OffloadGlobals<'ll>,
) {
    let cx = builder.cx;
    let OffloadKernelGlobals { offload_sizes, offload_entry, memtransfer_types, region_id } =
        offload_data;

    let tgt_decl = offload_globals.launcher_fn;
    let tgt_target_kernel_ty = offload_globals.launcher_ty;

    // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
    let tgt_bin_desc = offload_globals.bin_desc;

    let tgt_kernel_decl = offload_globals.kernel_args_ty;
    let begin_mapper_decl = offload_globals.begin_mapper;
    let end_mapper_decl = offload_globals.end_mapper;
    // Nine-argument type shared by the data-mapper functions.
    let fn_ty = offload_globals.mapper_fn_ty;

    let num_args = types.len() as u64;
    // Remember where we are; the allocas below are emitted elsewhere and this point is
    // restored afterwards.
    let ip = unsafe { llvm::LLVMRustGetInsertPoint(&builder.llbuilder) };

    // Volatile loads of the per-kernel size array and offload entry; the loaded values are
    // never used (presumably they keep these globals from being dropped as unused —
    // confirm).
    for val in [offload_sizes, offload_entry] {
        unsafe {
            let dummy = llvm::LLVMBuildLoad2(
                &builder.llbuilder,
                llvm::LLVMTypeOf(val),
                val,
                b"dummy\0".as_ptr() as *const _,
            );
            llvm::LLVMSetVolatile(dummy, llvm::TRUE);
        }
    }

    // Step 0: place the allocas after the function's existing allocas.
    unsafe {
        llvm::LLVMRustPositionBuilderPastAllocas(&builder.llbuilder, builder.llfn());
    }
    let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc");

    let ty = cx.type_array(cx.type_ptr(), num_args);
    // Base pointers: the kernel's input pointers.
    let a1 = builder.direct_alloca(ty, Align::EIGHT, ".offload_baseptrs");
    // Pointers: GEPs into the base pointers (see the loop over `args` below).
    let a2 = builder.direct_alloca(ty, Align::EIGHT, ".offload_ptrs");
    // Per-argument payload sizes in bytes, taken from `metadata`.
    let ty2 = cx.type_array(cx.type_i64(), num_args);
    let a4 = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes");

    let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args");

    // Step 1: back to the original insertion point.
    unsafe {
        llvm::LLVMRustRestoreInsertPoint(&builder.llbuilder, ip);
    }
    // 32 bytes == size of `{ i32, ptr, ptr, ptr }` with 8-byte pointers (assumes a 64-bit
    // host target).
    builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);

    let mut vals = vec![];
    let mut geps = vec![];
    let i32_0 = cx.get_const_i32(0);
    // A single zero index adds no offset, so each `geps[i]` addresses the same memory as
    // `args[i]`.
    for &v in args {
        let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]);
        vals.push(v);
        geps.push(gep);
    }

    // `void (ptr)`, used for the register/unregister calls. This local is distinct from
    // `fn_ty`, the nine-argument mapper type.
    let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void());
    let register_lib_decl = offload_globals.register_lib;
    let unregister_lib_decl = offload_globals.unregister_lib;
    let init_ty = cx.type_func(&[], cx.type_void());
    let init_rtls_decl = offload_globals.init_rtls;

    // call void @__tgt_register_lib(ptr %EmptyDesc)
    builder.call(mapper_fn_ty, None, None, register_lib_decl, &[tgt_bin_desc_alloca], None, None);
    // call void @__tgt_init_all_rtls()
    builder.call(init_ty, None, None, init_rtls_decl, &[], None, None);

    // Fill slot `i` of the base-pointer, pointer and size arrays.
    for i in 0..num_args {
        let idx = cx.get_const_i32(i);
        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]);
        builder.store(vals[i as usize], gep1, Align::EIGHT);
        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, idx]);
        builder.store(geps[i as usize], gep2, Align::EIGHT);
        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, idx]);
        builder.store(cx.get_const_i64(metadata[i as usize].payload_size), gep3, Align::EIGHT);
    }

    /// Returns pointers to element 0 of the base-pointer (`a1`), pointer (`a2`) and size
    /// (`a4`) arrays, in that order.
    fn get_geps<'ll, 'tcx>(
        builder: &mut Builder<'_, 'll, 'tcx>,
        ty: &'ll Type,
        ty2: &'ll Type,
        a1: &'ll Value,
        a2: &'ll Value,
        a4: &'ll Value,
    ) -> [&'ll Value; 3] {
        let cx = builder.cx;
        let i32_0 = cx.get_const_i32(0);

        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
        [gep1, gep2, gep3]
    }

    /// Calls the data-mapper `fn_to_call` with
    /// `(s_ident_t, i64 -1, i32 num_args, geps[0], geps[1], geps[2], o_type, null, null)`.
    fn generate_mapper_call<'ll, 'tcx>(
        builder: &mut Builder<'_, 'll, 'tcx>,
        geps: [&'ll Value; 3],
        o_type: &'ll Value,
        fn_to_call: &'ll Value,
        fn_ty: &'ll Type,
        num_args: u64,
        s_ident_t: &'ll Value,
    ) {
        let cx = builder.cx;
        let nullptr = cx.const_null(cx.type_ptr());
        // u64::MAX as an i64 constant is -1.
        let i64_max = cx.get_const_i64(u64::MAX);
        let num_args = cx.get_const_i32(num_args);
        let args =
            vec![s_ident_t, i64_max, num_args, geps[0], geps[1], geps[2], o_type, nullptr, nullptr];
        builder.call(fn_ty, None, None, fn_to_call, &args, None, None);
    }

    // Step 2: begin mapper.
    let s_ident_t = offload_globals.ident_t_global;
    let geps = get_geps(builder, ty, ty2, a1, a2, a4);
    generate_mapper_call(
        builder,
        geps,
        memtransfer_types,
        begin_mapper_decl,
        fn_ty,
        num_args,
        s_ident_t,
    );
    let values = KernelArgsTy::new(&cx, num_args, memtransfer_types, geps);

    // Step 3: store each (alignment, value) pair into field `i` of `kernel_args`.
    for (i, value) in values.iter().enumerate() {
        let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]);
        builder.store(value.1, ptr, value.0);
    }

    // 2097152 and 256 mirror the first components of the two [3 x i32] fields built by
    // `KernelArgsTy::new`.
    let args = vec![
        s_ident_t,
        cx.get_const_i64(u64::MAX),
        cx.get_const_i32(2097152),
        cx.get_const_i32(256),
        region_id,
        a5,
    ];
    builder.call(tgt_target_kernel_ty, None, None, tgt_decl, &args, None, None);
    // Step 4: end mapper, then unregister.
    let geps = get_geps(builder, ty, ty2, a1, a2, a4);
    generate_mapper_call(
        builder,
        geps,
        memtransfer_types,
        end_mapper_decl,
        fn_ty,
        num_args,
        s_ident_t,
    );

    builder.call(mapper_fn_ty, None, None, unregister_lib_decl, &[tgt_bin_desc_alloca], None, None);
}