use std::ffi::CString;

use bitflags::Flags;
use llvm::Linkage::*;
use rustc_abi::Align;
use rustc_codegen_ssa::common::TypeKind;
use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
use rustc_middle::bug;
use rustc_middle::ty::offload_meta::{MappingFlags, OffloadMetadata};

use crate::builder::Builder;
use crate::common::CodegenCx;
use crate::llvm::AttributePlace::Function;
use crate::llvm::{self, Linkage, Type, Value};
use crate::{SimpleCx, attributes};

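/// Per-module LLVM declarations shared by every offloaded kernel: the
/// `__tgt_target_kernel` launcher, the `__tgt_kernel_arguments` and
/// `__tgt_offload_entry` struct types, the begin/end data mappers, and the
/// `ident_t` source-location global. Declared once and reused for each kernel.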
pub(crate) struct OffloadGlobals<'ll> {
    pub launcher_fn: &'ll llvm::Value,
    pub launcher_ty: &'ll llvm::Type,

    pub kernel_args_ty: &'ll llvm::Type,

    pub offload_entry_ty: &'ll llvm::Type,

    pub begin_mapper: &'ll llvm::Value,
    pub end_mapper: &'ll llvm::Value,
    pub mapper_fn_ty: &'ll llvm::Type,

    pub ident_t_global: &'ll llvm::Value,
}

impl<'ll> OffloadGlobals<'ll> {
    pub(crate) fn declare(cx: &CodegenCx<'ll, '_>) -> Self {
        let (launcher_fn, launcher_ty) = generate_launcher(cx);
        let kernel_args_ty = KernelArgsTy::new_decl(cx);
        let offload_entry_ty = TgtOffloadEntry::new_decl(cx);
        let (begin_mapper, _, end_mapper, mapper_fn_ty) = gen_tgt_data_mappers(cx);
        let ident_t_global = generate_at_one(cx);

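        // The "openmp" module flag advertises which OpenMP version this module
        // claims to implement; 51 presumably corresponds to v5.1 (as with
        // clang's -fopenmp-version=51), which the offload machinery expects.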
        llvm::add_module_flag_u32(cx.llmod(), llvm::ModuleFlagMergeBehavior::Max, "openmp", 51);

        OffloadGlobals {
            launcher_fn,
            launcher_ty,
            kernel_args_ty,
            offload_entry_ty,
            begin_mapper,
            end_mapper,
            mapper_fn_ty,
            ident_t_global,
        }
    }
}

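/// Emits the registration boilerplate the offload runtime expects from a host
/// module: a `.omp_offloading.descriptor` global plus a pair of internal
/// ctor/dtor functions calling `__tgt_register_lib`/`__tgt_unregister_lib`.
/// The ctor is appended to `llvm.global_ctors`, so registration happens before
/// `main`; unregistration is scheduled from the ctor via `atexit`.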
pub(crate) fn register_offload<'ll>(cx: &CodegenCx<'ll, '_>) {
    let register_lib_name = "__tgt_register_lib";
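    // Registration is per-module; if the declaration already exists we have
    // emitted all of this before and can bail out early.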
    if cx.get_function(register_lib_name).is_some() {
        return;
    }

    let reg_lib_decl = cx.type_func(&[cx.type_ptr()], cx.type_void());
    let register_lib = declare_offload_fn(&cx, register_lib_name, reg_lib_decl);
    let unregister_lib = declare_offload_fn(&cx, "__tgt_unregister_lib", reg_lib_decl);

    let ptr_null = cx.const_null(cx.type_ptr());
    let const_struct = cx.const_struct(&[cx.get_const_i32(0), ptr_null, ptr_null, ptr_null], false);
    let omp_descriptor =
        add_global(cx, ".omp_offloading.descriptor", const_struct, InternalLinkage);
    let atexit = cx.type_func(&[cx.type_ptr()], cx.type_i32());
    let atexit_fn = declare_offload_fn(cx, "atexit", atexit);

    let init_ty = cx.type_func(&[], cx.type_void());
    let init_rtls = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);

    let desc_ty = cx.type_func(&[], cx.type_void());
    let reg_name = ".omp_offloading.descriptor_reg";
    let unreg_name = ".omp_offloading.descriptor_unreg";
    let desc_reg_fn = declare_offload_fn(cx, reg_name, desc_ty);
    let desc_unreg_fn = declare_offload_fn(cx, unreg_name, desc_ty);
    llvm::set_linkage(desc_reg_fn, InternalLinkage);
    llvm::set_linkage(desc_unreg_fn, InternalLinkage);
    llvm::set_section(desc_reg_fn, c".text.startup");
    llvm::set_section(desc_unreg_fn, c".text.startup");

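    // Body of the registration ctor. Roughly, matching the calls below:
    //   define internal void @.omp_offloading.descriptor_reg() section ".text.startup" {
    //     call void @__tgt_register_lib(ptr @.omp_offloading.descriptor)
    //     call void @__tgt_init_all_rtls()
    //     %1 = call i32 @atexit(ptr @.omp_offloading.descriptor_unreg)
    //     ret void
    //   }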
    let bb = Builder::append_block(cx, desc_reg_fn, "entry");
    let mut a = Builder::build(cx, bb);
    a.call(reg_lib_decl, None, None, register_lib, &[omp_descriptor], None, None);
    a.call(init_ty, None, None, init_rtls, &[], None, None);
    a.call(atexit, None, None, atexit_fn, &[desc_unreg_fn], None, None);
    a.ret_void();

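    // The matching dtor only has to unregister the descriptor again.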
    let bb = Builder::append_block(cx, desc_unreg_fn, "entry");
    let mut a = Builder::build(cx, bb);
    a.call(reg_lib_decl, None, None, unregister_lib, &[omp_descriptor], None, None);
    a.ret_void();

    let args = vec![cx.get_const_i32(101), desc_reg_fn, ptr_null];
    let const_struct = cx.const_struct(&args, false);
    let arr = cx.const_array(cx.val_ty(const_struct), &[const_struct]);
    add_global(cx, "llvm.global_ctors", arr, AppendingLinkage);
}

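/// Grid geometry for one kernel launch: the raw `[u32; 3]` workgroup and
/// thread dimensions, plus their products, which the launcher consumes as the
/// total number of workgroups and threads per block.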
pub(crate) struct OffloadKernelDims<'ll> {
    num_workgroups: &'ll Value,
    threads_per_block: &'ll Value,
    workgroup_dims: &'ll Value,
    thread_dims: &'ll Value,
}

impl<'ll> OffloadKernelDims<'ll> {
    pub(crate) fn from_operands<'tcx>(
        builder: &mut Builder<'_, 'll, 'tcx>,
        workgroup_op: &OperandRef<'tcx, &'ll llvm::Value>,
        thread_op: &OperandRef<'tcx, &'ll llvm::Value>,
    ) -> Self {
        let cx = builder.cx;
        let arr_ty = cx.type_array(cx.type_i32(), 3);
        let four = Align::from_bytes(4).unwrap();

        let OperandValue::Ref(place) = workgroup_op.val else {
            bug!("expected array operand by reference");
        };
        let workgroup_val = builder.load(arr_ty, place.llval, four);

        let OperandValue::Ref(place) = thread_op.val else {
            bug!("expected array operand by reference");
        };
        let thread_val = builder.load(arr_ty, place.llval, four);

        fn mul_dim3<'ll, 'tcx>(
            builder: &mut Builder<'_, 'll, 'tcx>,
            arr: &'ll Value,
        ) -> &'ll Value {
            let x = builder.extract_value(arr, 0);
            let y = builder.extract_value(arr, 1);
            let z = builder.extract_value(arr, 2);

            let xy = builder.mul(x, y);
            builder.mul(xy, z)
        }

        let num_workgroups = mul_dim3(builder, workgroup_val);
        let threads_per_block = mul_dim3(builder, thread_val);

        OffloadKernelDims {
            workgroup_dims: workgroup_val,
            thread_dims: thread_val,
            num_workgroups,
            threads_per_block,
        }
    }
}

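// Declares the `__tgt_target_kernel` launcher:
//   i32 @__tgt_target_kernel(ptr, i64, i32, i32, ptr, ptr)
// Judging by how it is called below, the arguments are (ident, device_id,
// num_teams, thread_limit, host_fn_ptr, kernel_args); the exact parameter
// names are an assumption, not something this module defines.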
fn generate_launcher<'ll>(cx: &CodegenCx<'ll, '_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
    let tptr = cx.type_ptr();
    let ti64 = cx.type_i64();
    let ti32 = cx.type_i32();
    let args = vec![tptr, ti64, ti32, ti32, tptr, tptr];
    let tgt_fn_ty = cx.type_func(&args, ti32);
    let name = "__tgt_target_kernel";
    let tgt_decl = declare_offload_fn(&cx, name, tgt_fn_ty);
    let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
    attributes::apply_to_llfn(tgt_decl, Function, &[nounwind]);
    (tgt_decl, tgt_fn_ty)
}

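// Emits the pair of globals backing the `ident_t` argument that every
// libomptarget call takes; roughly:
//   @0 = private unnamed_addr constant [23 x i8] c";unknown;unknown;0;0;;\00", align 1
//   @1 = private unnamed_addr constant %struct.ident_t { i32 0, i32 2, i32 0, i32 22, ptr @0 }, align 8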
pub(crate) fn generate_at_one<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll llvm::Value {
    let unknown_txt = ";unknown;unknown;0;0;;";
    let c_entry_name = CString::new(unknown_txt).unwrap();
    let c_val = c_entry_name.as_bytes_with_nul();
    let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
    let at_zero = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
    llvm::set_alignment(at_zero, Align::ONE);

    let struct_ident_ty = cx.type_named_struct("struct.ident_t");
    let struct_elems = vec![
        cx.get_const_i32(0),
        cx.get_const_i32(2),
        cx.get_const_i32(0),
        cx.get_const_i32(22),
        at_zero,
    ];
    let struct_elems_ty: Vec<_> = struct_elems.iter().map(|&x| cx.val_ty(x)).collect();
    let initializer = crate::common::named_struct(struct_ident_ty, &struct_elems);
    cx.set_struct_body(struct_ident_ty, &struct_elems_ty, false);
    let at_one = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
    llvm::set_alignment(at_one, Align::EIGHT);
    at_one
}

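// Host-side mirror of the runtime's `__tgt_offload_entry` record. The field
// meaning below is inferred from `new`; names are an educated guess:
//   u64 reserved, u16 version, u16 kind, u32 flags,
//   ptr address (the region id), ptr symbol_name, u64 size, u64 data, ptr aux_addr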
pub(crate) struct TgtOffloadEntry {
}

impl TgtOffloadEntry {
    pub(crate) fn new_decl<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll llvm::Type {
        let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
        let tptr = cx.type_ptr();
        let ti64 = cx.type_i64();
        let ti32 = cx.type_i32();
        let ti16 = cx.type_i16();
        let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
        cx.set_struct_body(offload_entry_ty, &entry_elements, false);
        offload_entry_ty
    }

    fn new<'ll>(
        cx: &CodegenCx<'ll, '_>,
        region_id: &'ll Value,
        llglobal: &'ll Value,
    ) -> [&'ll Value; 9] {
        let reserved = cx.get_const_i64(0);
        let version = cx.get_const_i16(1);
        let kind = cx.get_const_i16(1);
        let flags = cx.get_const_i32(0);
        let size = cx.get_const_i64(0);
        let data = cx.get_const_i64(0);
        let aux_addr = cx.const_null(cx.type_ptr());
        [reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr]
    }
}

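// Host-side mirror of `struct __tgt_kernel_arguments`, in the order `new`
// fills the 13 fields: version, number of args, the three offload arrays
// (base pointers, pointers, sizes), the map types, two pointer fields left
// null here (presumably arg names and custom mappers), tripcount, flags,
// workgroup dims, thread dims, and a trailing i32 set to 0.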
struct KernelArgsTy {
}

impl KernelArgsTy {
    const OFFLOAD_VERSION: u64 = 3;
    const FLAGS: u64 = 0;
    const TRIPCOUNT: u64 = 0;
    fn new_decl<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll Type {
        let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
        let tptr = cx.type_ptr();
        let ti64 = cx.type_i64();
        let ti32 = cx.type_i32();
        let tarr = cx.type_array(ti32, 3);

        let kernel_elements =
            vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];

        cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
        kernel_arguments_ty
    }

    fn new<'ll, 'tcx>(
        cx: &CodegenCx<'ll, 'tcx>,
        num_args: u64,
        memtransfer_types: &'ll Value,
        geps: [&'ll Value; 3],
        workgroup_dims: &'ll Value,
        thread_dims: &'ll Value,
    ) -> [(Align, &'ll Value); 13] {
        let four = Align::from_bytes(4).expect("4 Byte alignment should work");
        let eight = Align::EIGHT;

        [
            (four, cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION)),
            (four, cx.get_const_i32(num_args)),
            (eight, geps[0]),
            (eight, geps[1]),
            (eight, geps[2]),
            (eight, memtransfer_types),
            (eight, cx.const_null(cx.type_ptr())),
            (eight, cx.const_null(cx.type_ptr())),
            (eight, cx.get_const_i64(KernelArgsTy::TRIPCOUNT)),
            (eight, cx.get_const_i64(KernelArgsTy::FLAGS)),
            (four, workgroup_dims),
            (four, thread_dims),
            (four, cx.get_const_i32(0)),
        ]
    }
}

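/// Per-kernel globals produced by [`gen_define_handling`] and consumed by
/// [`gen_call_handling`] when the actual launch sequence is emitted.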
#[derive(Copy, Clone)]
pub(crate) struct OffloadKernelGlobals<'ll> {
    pub offload_sizes: &'ll llvm::Value,
    pub memtransfer_begin: &'ll llvm::Value,
    pub memtransfer_kernel: &'ll llvm::Value,
    pub memtransfer_end: &'ll llvm::Value,
    pub region_id: &'ll llvm::Value,
}

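// Declares the three libomptarget data-mapper entry points, which share a
// single signature:
//   void @__tgt_target_data_{begin,update,end}_mapper(ptr, i64, i32, ptr, ptr, ptr, ptr, ptr, ptr)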
fn gen_tgt_data_mappers<'ll>(
    cx: &CodegenCx<'ll, '_>,
) -> (&'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Type) {
    let tptr = cx.type_ptr();
    let ti64 = cx.type_i64();
    let ti32 = cx.type_i32();

    let args = vec![tptr, ti64, ti32, tptr, tptr, tptr, tptr, tptr, tptr];
    let mapper_fn_ty = cx.type_func(&args, cx.type_void());
    let mapper_begin = "__tgt_target_data_begin_mapper";
    let mapper_update = "__tgt_target_data_update_mapper";
    let mapper_end = "__tgt_target_data_end_mapper";
    let begin_mapper_decl = declare_offload_fn(&cx, mapper_begin, mapper_fn_ty);
    let update_mapper_decl = declare_offload_fn(&cx, mapper_update, mapper_fn_ty);
    let end_mapper_decl = declare_offload_fn(&cx, mapper_end, mapper_fn_ty);

    let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
    attributes::apply_to_llfn(begin_mapper_decl, Function, &[nounwind]);
    attributes::apply_to_llfn(update_mapper_decl, Function, &[nounwind]);
    attributes::apply_to_llfn(end_mapper_decl, Function, &[nounwind]);

    (begin_mapper_decl, update_mapper_decl, end_mapper_decl, mapper_fn_ty)
}

fn add_priv_unnamed_arr<'ll>(cx: &SimpleCx<'ll>, name: &str, vals: &[u64]) -> &'ll llvm::Value {
    let ti64 = cx.type_i64();
    let mut size_val = Vec::with_capacity(vals.len());
    for &val in vals {
        size_val.push(cx.get_const_i64(val));
    }
    let initializer = cx.const_array(ti64, &size_val);
    add_unnamed_global(cx, name, initializer, PrivateLinkage)
}

pub(crate) fn add_unnamed_global<'ll>(
    cx: &SimpleCx<'ll>,
    name: &str,
    initializer: &'ll llvm::Value,
    l: Linkage,
) -> &'ll llvm::Value {
    let llglobal = add_global(cx, name, initializer, l);
    llvm::LLVMSetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global);
    llglobal
}

pub(crate) fn add_global<'ll>(
    cx: &SimpleCx<'ll>,
    name: &str,
    initializer: &'ll llvm::Value,
    l: Linkage,
) -> &'ll llvm::Value {
    let c_name = CString::new(name).unwrap();
    let llglobal: &'ll llvm::Value = llvm::add_global(cx.llmod, cx.val_ty(initializer), &c_name);
    llvm::set_global_constant(llglobal, true);
    llvm::set_linkage(llglobal, l);
    llvm::set_initializer(llglobal, initializer);
    llglobal
}

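/// Emits the per-kernel globals: `.offload_sizes.*` with the payload sizes,
/// the `.offload_maptypes.*` arrays for the begin/kernel/end phases, a weak
/// `.<symbol>.region_id` byte identifying the kernel, and a
/// `__tgt_offload_entry` record in the `llvm_offload_entries` section, which
/// the runtime scans at registration time. Results are cached per symbol.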
pub(crate) fn gen_define_handling<'ll>(
    cx: &CodegenCx<'ll, '_>,
    metadata: &[OffloadMetadata],
    symbol: String,
    offload_globals: &OffloadGlobals<'ll>,
) -> OffloadKernelGlobals<'ll> {
    if let Some(entry) = cx.offload_kernel_cache.borrow().get(&symbol) {
        return *entry;
    }

    let offload_entry_ty = offload_globals.offload_entry_ty;

    let (sizes, transfer): (Vec<_>, Vec<_>) =
        metadata.iter().map(|m| (m.payload_size, m.mode)).unzip();
    let handled_mappings = MappingFlags::TO
        | MappingFlags::FROM
        | MappingFlags::TARGET_PARAM
        | MappingFlags::LITERAL
        | MappingFlags::IMPLICIT;
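    // Sanity-check the metadata: only this subset of mapping flags is lowered
    // here, so reject unknown or unhandled bits in debug builds.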
    for arg in &transfer {
        debug_assert!(!arg.contains_unknown_bits());
        debug_assert!(handled_mappings.contains(*arg));
    }

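    // Split the map types by phase: the `begin` mapper copies TO-mapped data to
    // the device, the `end` mapper copies FROM-mapped data back, and the kernel
    // launch itself only carries the TARGET_PARAM marker.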
    let valid_begin_mappings = MappingFlags::TO | MappingFlags::LITERAL | MappingFlags::IMPLICIT;
    let transfer_to: Vec<u64> =
        transfer.iter().map(|m| m.intersection(valid_begin_mappings).bits()).collect();
    let transfer_from: Vec<u64> =
        transfer.iter().map(|m| m.intersection(MappingFlags::FROM).bits()).collect();
    let transfer_kernel = vec![MappingFlags::TARGET_PARAM.bits(); transfer_to.len()];

    let offload_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &sizes);
    let memtransfer_begin =
        add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}.begin"), &transfer_to);
    let memtransfer_kernel =
        add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}.kernel"), &transfer_kernel);
    let memtransfer_end =
        add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}.end"), &transfer_from);

    let name = format!(".{symbol}.region_id");
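    // The region id byte is never meaningfully read; its address is what
    // identifies the kernel to the runtime.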
    let initializer = cx.get_const_i8(0);
    let region_id = add_global(&cx, &name, initializer, WeakAnyLinkage);

    let c_entry_name = CString::new(symbol.clone()).unwrap();
    let c_val = c_entry_name.as_bytes_with_nul();
    let offload_entry_name = format!(".offloading.entry_name.{symbol}");

    let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
    let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
    llvm::set_alignment(llglobal, Align::ONE);
    llvm::set_section(llglobal, c".llvm.rodata.offloading");

    let name = format!(".offloading.entry.{symbol}");

    let elems = TgtOffloadEntry::new(&cx, region_id, llglobal);

    let initializer = crate::common::named_struct(offload_entry_ty, &elems);
    let c_name = CString::new(name).unwrap();
    let offload_entry = llvm::add_global(cx.llmod, offload_entry_ty, &c_name);
    llvm::set_global_constant(offload_entry, true);
    llvm::set_linkage(offload_entry, WeakAnyLinkage);
    llvm::set_initializer(offload_entry, initializer);
    llvm::set_alignment(offload_entry, Align::EIGHT);
    let c_section_name = CString::new("llvm_offload_entries").unwrap();
    llvm::set_section(offload_entry, &c_section_name);

    cx.add_compiler_used_global(offload_entry);

    let result = OffloadKernelGlobals {
        offload_sizes,
        memtransfer_begin,
        memtransfer_kernel,
        memtransfer_end,
        region_id,
    };

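    // Marking the size array as compiler-used keeps it from being dropped by
    // global dead-code elimination, like the entry global above.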
    cx.add_compiler_used_global(result.offload_sizes);

    cx.offload_kernel_cache.borrow_mut().insert(symbol, result);

    result
}

fn declare_offload_fn<'ll>(
    cx: &CodegenCx<'ll, '_>,
    name: &str,
    ty: &'ll llvm::Type,
) -> &'ll llvm::Value {
    crate::declare::declare_simple_fn(
        cx,
        name,
        llvm::CallConv::CCallConv,
        llvm::UnnamedAddr::No,
        llvm::Visibility::Default,
        ty,
    )
}

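/// Width in bits of a scalar LLVM type; bugs out on aggregates and pointers.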
pub(crate) fn scalar_width<'ll>(cx: &'ll SimpleCx<'_>, ty: &'ll Type) -> u64 {
    match cx.type_kind(ty) {
        TypeKind::Half
        | TypeKind::Float
        | TypeKind::Double
        | TypeKind::X86_FP80
        | TypeKind::FP128
        | TypeKind::PPC_FP128 => cx.float_width(ty) as u64,
        TypeKind::Integer => cx.int_width(ty),
        other => bug!("scalar_width was called on a non scalar type {other:?}"),
    }
}

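/// Emits the host-side launch sequence for one kernel call. A rough sketch of
/// the resulting IR, with allocas hoisted into the entry block:
///   fill `.offload_baseptrs`/`.offload_ptrs`/`.offload_sizes` from `args`/`metadata`
///   call void @__tgt_target_data_begin_mapper(...)   ; copy TO-data over
///   call i32 @__tgt_target_kernel(ptr @1, i64 -1, i32 %teams, i32 %threads, ptr %region_id, ptr %kernel_args)
///   call void @__tgt_target_data_end_mapper(...)     ; copy FROM-data back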
pub(crate) fn gen_call_handling<'ll, 'tcx>(
    builder: &mut Builder<'_, 'll, 'tcx>,
    offload_data: &OffloadKernelGlobals<'ll>,
    args: &[&'ll Value],
    types: &[&Type],
    metadata: &[OffloadMetadata],
    offload_globals: &OffloadGlobals<'ll>,
    offload_dims: &OffloadKernelDims<'ll>,
) {
    let cx = builder.cx;
    let OffloadKernelGlobals {
        memtransfer_begin,
        memtransfer_kernel,
        memtransfer_end,
        region_id,
        ..
    } = offload_data;
    let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } =
        offload_dims;

    let tgt_decl = offload_globals.launcher_fn;
    let tgt_target_kernel_ty = offload_globals.launcher_ty;

    let tgt_kernel_decl = offload_globals.kernel_args_ty;
    let begin_mapper_decl = offload_globals.begin_mapper;
    let end_mapper_decl = offload_globals.end_mapper;
    let fn_ty = offload_globals.mapper_fn_ty;

    let num_args = types.len() as u64;
    let bb = builder.llbb();

    unsafe {
        llvm::LLVMRustPositionBuilderPastAllocas(&builder.llbuilder, builder.llfn());
    }

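    // Host-side scaffolding, hoisted into the entry block: the three offload
    // arrays (base pointers, pointers, sizes) plus the kernel_args struct.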
    let ty = cx.type_array(cx.type_ptr(), num_args);
    let a1 = builder.direct_alloca(ty, Align::EIGHT, ".offload_baseptrs");
    let a2 = builder.direct_alloca(ty, Align::EIGHT, ".offload_ptrs");
    let ty2 = cx.type_array(cx.type_i64(), num_args);
    let a4 = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes");

    let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args");

    unsafe {
        llvm::LLVMPositionBuilderAtEnd(&builder.llbuilder, bb);
    }

    let mut vals = vec![];
    let mut geps = vec![];
    let i32_0 = cx.get_const_i32(0);
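    // Pointer arguments can be handed to the runtime as-is; scalar arguments
    // are widened to i64 and spilled to a stack slot so there is an address to
    // pass along.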
    for &v in args {
        let ty = cx.val_ty(v);
        let ty_kind = cx.type_kind(ty);
        let (base_val, gep_base) = match ty_kind {
            TypeKind::Pointer => (v, v),
            TypeKind::Half | TypeKind::Float | TypeKind::Double | TypeKind::Integer => {
                let num_bits = scalar_width(cx, ty);

                let bb = builder.llbb();
                unsafe {
                    llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, builder.llfn());
                }
                let addr = builder.direct_alloca(cx.type_i64(), Align::EIGHT, "addr");
                unsafe {
                    llvm::LLVMPositionBuilderAtEnd(builder.llbuilder, bb);
                }

                let cast = builder.bitcast(v, cx.type_ix(num_bits));
                let value = builder.zext(cast, cx.type_i64());
                builder.store(value, addr, Align::EIGHT);
                (value, addr)
            }
            other => bug!("offload does not support {other:?}"),
        };

        let gep = builder.inbounds_gep(cx.type_f32(), gep_base, &[i32_0]);

        vals.push(base_val);
        geps.push(gep);
    }

    for i in 0..num_args {
        let idx = cx.get_const_i32(i);
        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]);
        builder.store(vals[i as usize], gep1, Align::EIGHT);
        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, idx]);
        builder.store(geps[i as usize], gep2, Align::EIGHT);
        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, idx]);
        builder.store(cx.get_const_i64(metadata[i as usize].payload_size), gep3, Align::EIGHT);
    }

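    // GEPs to the first element of each offload array; these are what the
    // mapper calls and the kernel_args struct actually consume.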
    fn get_geps<'ll, 'tcx>(
        builder: &mut Builder<'_, 'll, 'tcx>,
        ty: &'ll Type,
        ty2: &'ll Type,
        a1: &'ll Value,
        a2: &'ll Value,
        a4: &'ll Value,
    ) -> [&'ll Value; 3] {
        let cx = builder.cx;
        let i32_0 = cx.get_const_i32(0);

        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
        [gep1, gep2, gep3]
    }

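    // Shared emitter for the begin/end mapper calls. The i64 -1 device id
    // presumably selects the default device, mirroring the launcher call
    // below; the two trailing null pointers stand in for the unused arg-name
    // and custom-mapper arrays.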
    fn generate_mapper_call<'ll, 'tcx>(
        builder: &mut Builder<'_, 'll, 'tcx>,
        geps: [&'ll Value; 3],
        o_type: &'ll Value,
        fn_to_call: &'ll Value,
        fn_ty: &'ll Type,
        num_args: u64,
        s_ident_t: &'ll Value,
    ) {
        let cx = builder.cx;
        let nullptr = cx.const_null(cx.type_ptr());
        let i64_max = cx.get_const_i64(u64::MAX);
        let num_args = cx.get_const_i32(num_args);
        let args =
            vec![s_ident_t, i64_max, num_args, geps[0], geps[1], geps[2], o_type, nullptr, nullptr];
        builder.call(fn_ty, None, None, fn_to_call, &args, None, None);
    }

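    // Launch protocol: map the inputs, fill kernel_args, launch, then unmap.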
    let s_ident_t = offload_globals.ident_t_global;
    let geps = get_geps(builder, ty, ty2, a1, a2, a4);
    generate_mapper_call(
        builder,
        geps,
        memtransfer_begin,
        begin_mapper_decl,
        fn_ty,
        num_args,
        s_ident_t,
    );
    let values =
        KernelArgsTy::new(&cx, num_args, memtransfer_kernel, geps, workgroup_dims, thread_dims);

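    // Write the kernel_args struct field by field, with the per-field
    // alignment recorded by `KernelArgsTy::new`.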
    for (i, value) in values.iter().enumerate() {
        let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]);
        builder.store(value.1, ptr, value.0);
    }

    let args = vec![
        s_ident_t,
        cx.get_const_i64(u64::MAX),
        num_workgroups,
        threads_per_block,
        region_id,
        a5,
    ];
    builder.call(tgt_target_kernel_ty, None, None, tgt_decl, &args, None, None);
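    // Finally transfer FROM-mapped results back to the host and release the
    // device-side mappings.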
    let geps = get_geps(builder, ty, ty2, a1, a2, a4);
    generate_mapper_call(
        builder,
        geps,
        memtransfer_end,
        end_mapper_decl,
        fn_ty,
        num_args,
        s_ident_t,
    );
}