1use std::ffi::CString;
2
3use bitflags::Flags;
4use llvm::Linkage::*;
5use rustc_abi::Align;
6use rustc_codegen_ssa::MemFlags;
7use rustc_codegen_ssa::common::TypeKind;
8use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
9use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
10use rustc_middle::bug;
11use rustc_middle::ty::offload_meta::{MappingFlags, OffloadMetadata, OffloadSize};
12
13use crate::builder::Builder;
14use crate::common::CodegenCx;
15use crate::llvm::AttributePlace::Function;
16use crate::llvm::{self, Linkage, Type, Value};
17use crate::{SimpleCx, attributes};
18
19pub(crate) struct OffloadGlobals<'ll> {
21 pub launcher_fn: &'ll llvm::Value,
22 pub launcher_ty: &'ll llvm::Type,
23
24 pub kernel_args_ty: &'ll llvm::Type,
25
26 pub offload_entry_ty: &'ll llvm::Type,
27
28 pub begin_mapper: &'ll llvm::Value,
29 pub end_mapper: &'ll llvm::Value,
30 pub mapper_fn_ty: &'ll llvm::Type,
31
32 pub ident_t_global: &'ll llvm::Value,
33}
34
35impl<'ll> OffloadGlobals<'ll> {
36 pub(crate) fn declare(cx: &CodegenCx<'ll, '_>) -> Self {
37 let (launcher_fn, launcher_ty) = generate_launcher(cx);
38 let kernel_args_ty = KernelArgsTy::new_decl(cx);
39 let offload_entry_ty = TgtOffloadEntry::new_decl(cx);
40 let (begin_mapper, _, end_mapper, mapper_fn_ty) = gen_tgt_data_mappers(cx);
41 let ident_t_global = generate_at_one(cx);
42
43 llvm::add_module_flag_u32(cx.llmod(), llvm::ModuleFlagMergeBehavior::Max, "openmp", 51);
46
47 OffloadGlobals {
48 launcher_fn,
49 launcher_ty,
50 kernel_args_ty,
51 offload_entry_ty,
52 begin_mapper,
53 end_mapper,
54 mapper_fn_ty,
55 ident_t_global,
56 }
57 }
58}
59
60pub(crate) fn register_offload<'ll>(cx: &CodegenCx<'ll, '_>) {
67 let register_lib_name = "__tgt_register_lib";
70 if cx.get_function(register_lib_name).is_some() {
71 return;
72 }
73
74 let reg_lib_decl = cx.type_func(&[cx.type_ptr()], cx.type_void());
75 let register_lib = declare_offload_fn(&cx, register_lib_name, reg_lib_decl);
76 let unregister_lib = declare_offload_fn(&cx, "__tgt_unregister_lib", reg_lib_decl);
77
78 let ptr_null = cx.const_null(cx.type_ptr());
79 let const_struct = cx.const_struct(&[cx.get_const_i32(0), ptr_null, ptr_null, ptr_null], false);
80 let omp_descriptor =
81 add_global(cx, ".omp_offloading.descriptor", const_struct, InternalLinkage);
82 let atexit = cx.type_func(&[cx.type_ptr()], cx.type_i32());
86 let atexit_fn = declare_offload_fn(cx, "atexit", atexit);
87
88 let init_ty = cx.type_func(&[], cx.type_void());
91 let init_rtls = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);
92
93 let desc_ty = cx.type_func(&[], cx.type_void());
94 let reg_name = ".omp_offloading.descriptor_reg";
95 let unreg_name = ".omp_offloading.descriptor_unreg";
96 let desc_reg_fn = declare_offload_fn(cx, reg_name, desc_ty);
97 let desc_unreg_fn = declare_offload_fn(cx, unreg_name, desc_ty);
98 llvm::set_linkage(desc_reg_fn, InternalLinkage);
99 llvm::set_linkage(desc_unreg_fn, InternalLinkage);
100 llvm::set_section(desc_reg_fn, c".text.startup");
101 llvm::set_section(desc_unreg_fn, c".text.startup");
102
103 let bb = Builder::append_block(cx, desc_reg_fn, "entry");
111 let mut a = Builder::build(cx, bb);
112 a.call(reg_lib_decl, None, None, register_lib, &[omp_descriptor], None, None);
113 a.call(init_ty, None, None, init_rtls, &[], None, None);
114 a.call(atexit, None, None, atexit_fn, &[desc_unreg_fn], None, None);
115 a.ret_void();
116
117 let bb = Builder::append_block(cx, desc_unreg_fn, "entry");
123 let mut a = Builder::build(cx, bb);
124 a.call(reg_lib_decl, None, None, unregister_lib, &[omp_descriptor], None, None);
125 a.ret_void();
126
127 let args = ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
[cx.get_const_i32(101), desc_reg_fn, ptr_null]))vec![cx.get_const_i32(101), desc_reg_fn, ptr_null];
129 let const_struct = cx.const_struct(&args, false);
130 let arr = cx.const_array(cx.val_ty(const_struct), &[const_struct]);
131 add_global(cx, "llvm.global_ctors", arr, AppendingLinkage);
132}
133
134pub(crate) struct OffloadKernelDims<'ll> {
135 num_workgroups: &'ll Value,
136 threads_per_block: &'ll Value,
137 workgroup_dims: &'ll Value,
138 thread_dims: &'ll Value,
139}
140
141impl<'ll> OffloadKernelDims<'ll> {
142 pub(crate) fn from_operands<'tcx>(
143 builder: &mut Builder<'_, 'll, 'tcx>,
144 workgroup_op: &OperandRef<'tcx, &'ll llvm::Value>,
145 thread_op: &OperandRef<'tcx, &'ll llvm::Value>,
146 ) -> Self {
147 let cx = builder.cx;
148 let arr_ty = cx.type_array(cx.type_i32(), 3);
149 let four = Align::from_bytes(4).unwrap();
150
151 let OperandValue::Ref(place) = workgroup_op.val else {
152 ::rustc_middle::util::bug::bug_fmt(format_args!("expected array operand by reference"));bug!("expected array operand by reference");
153 };
154 let workgroup_val = builder.load(arr_ty, place.llval, four);
155
156 let OperandValue::Ref(place) = thread_op.val else {
157 ::rustc_middle::util::bug::bug_fmt(format_args!("expected array operand by reference"));bug!("expected array operand by reference");
158 };
159 let thread_val = builder.load(arr_ty, place.llval, four);
160
161 fn mul_dim3<'ll, 'tcx>(
162 builder: &mut Builder<'_, 'll, 'tcx>,
163 arr: &'ll Value,
164 ) -> &'ll Value {
165 let x = builder.extract_value(arr, 0);
166 let y = builder.extract_value(arr, 1);
167 let z = builder.extract_value(arr, 2);
168
169 let xy = builder.mul(x, y);
170 builder.mul(xy, z)
171 }
172
173 let num_workgroups = mul_dim3(builder, workgroup_val);
174 let threads_per_block = mul_dim3(builder, thread_val);
175
176 OffloadKernelDims {
177 workgroup_dims: workgroup_val,
178 thread_dims: thread_val,
179 num_workgroups,
180 threads_per_block,
181 }
182 }
183}
184
185fn generate_launcher<'ll>(cx: &CodegenCx<'ll, '_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
188 let tptr = cx.type_ptr();
189 let ti64 = cx.type_i64();
190 let ti32 = cx.type_i32();
191 let args = ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
[tptr, ti64, ti32, ti32, tptr, tptr]))vec![tptr, ti64, ti32, ti32, tptr, tptr];
192 let tgt_fn_ty = cx.type_func(&args, ti32);
193 let name = "__tgt_target_kernel";
194 let tgt_decl = declare_offload_fn(&cx, name, tgt_fn_ty);
195 let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
196 attributes::apply_to_llfn(tgt_decl, Function, &[nounwind]);
197 (tgt_decl, tgt_fn_ty)
198}
199
200pub(crate) fn generate_at_one<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll llvm::Value {
206 let unknown_txt = ";unknown;unknown;0;0;;";
207 let c_entry_name = CString::new(unknown_txt).unwrap();
208 let c_val = c_entry_name.as_bytes_with_nul();
209 let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
210 let at_zero = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
211 llvm::set_alignment(at_zero, Align::ONE);
212
213 let struct_ident_ty = cx.type_named_struct("struct.ident_t");
215 let struct_elems = ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
[cx.get_const_i32(0), cx.get_const_i32(2), cx.get_const_i32(0),
cx.get_const_i32(22), at_zero]))vec![
216 cx.get_const_i32(0),
217 cx.get_const_i32(2),
218 cx.get_const_i32(0),
219 cx.get_const_i32(22),
220 at_zero,
221 ];
222 let struct_elems_ty: Vec<_> = struct_elems.iter().map(|&x| cx.val_ty(x)).collect();
223 let initializer = crate::common::named_struct(struct_ident_ty, &struct_elems);
224 cx.set_struct_body(struct_ident_ty, &struct_elems_ty, false);
225 let at_one = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
226 llvm::set_alignment(at_one, Align::EIGHT);
227 at_one
228}
229
/// Host-side view of the `struct.__tgt_offload_entry` record. The Rust type holds no data;
/// it only groups `TgtOffloadEntry::new_decl` (the LLVM type) and `TgtOffloadEntry::new`
/// (the initializer fields). Field layout, in order:
pub(crate) struct TgtOffloadEntry {
    //   i64  reserved   (0)
    //   i16  version    (1)
    //   i16  kind       (1)
    //   i32  flags      (0)
    //   ptr  region id of the kernel
    //   ptr  entry-name string
    //   i64  size       (0)
    //   i64  data       (0)
    //   ptr  aux_addr   (null)
}
241
242impl TgtOffloadEntry {
243 pub(crate) fn new_decl<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll llvm::Type {
244 let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
245 let tptr = cx.type_ptr();
246 let ti64 = cx.type_i64();
247 let ti32 = cx.type_i32();
248 let ti16 = cx.type_i16();
249 let entry_elements = ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
[ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr]))vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
252 cx.set_struct_body(offload_entry_ty, &entry_elements, false);
253 offload_entry_ty
254 }
255
256 fn new<'ll>(
257 cx: &CodegenCx<'ll, '_>,
258 region_id: &'ll Value,
259 llglobal: &'ll Value,
260 ) -> [&'ll Value; 9] {
261 let reserved = cx.get_const_i64(0);
262 let version = cx.get_const_i16(1);
263 let kind = cx.get_const_i16(1);
264 let flags = cx.get_const_i32(0);
265 let size = cx.get_const_i64(0);
266 let data = cx.get_const_i64(0);
267 let aux_addr = cx.const_null(cx.type_ptr());
268 [reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr]
269 }
270}
271
/// Host-side view of the `struct.__tgt_kernel_arguments` record. The Rust type holds no
/// data; it only groups `KernelArgsTy::new_decl` (the LLVM type) and `KernelArgsTy::new`
/// (the per-launch field values). Field layout, in order:
struct KernelArgsTy {
    //   i32        Version
    //   i32        NumArgs
    //   ptr        ArgBasePtrs
    //   ptr        ArgPtrs
    //   ptr        ArgSizes
    //   ptr        ArgTypes
    //   ptr        ArgNames
    //   ptr        ArgMappers
    //   i64        Tripcount
    //   i64        Flags
    //   [3 x i32]  NumTeams
    //   [3 x i32]  ThreadLimit
    //   i32        DynCGroupMem
}
296
297impl KernelArgsTy {
298 const OFFLOAD_VERSION: u64 = 3;
299 const FLAGS: u64 = 0;
300 const TRIPCOUNT: u64 = 0;
301 fn new_decl<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll Type {
302 let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
303 let tptr = cx.type_ptr();
304 let ti64 = cx.type_i64();
305 let ti32 = cx.type_i32();
306 let tarr = cx.type_array(ti32, 3);
307
308 let kernel_elements =
309 ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
[ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr,
tarr, ti32]))vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
310
311 cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
312 kernel_arguments_ty
313 }
314
315 fn new<'ll, 'tcx>(
316 cx: &CodegenCx<'ll, 'tcx>,
317 num_args: u64,
318 memtransfer_types: &'ll Value,
319 geps: [&'ll Value; 3],
320 workgroup_dims: &'ll Value,
321 thread_dims: &'ll Value,
322 dyn_cache: &'ll Value,
323 ) -> [(Align, &'ll str, &'ll Value); 13] {
324 let four = Align::from_bytes(4).expect("4 Byte alignment should work");
325 let eight = Align::EIGHT;
326
327 [
328 (four, "Version", cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION)),
329 (four, "NumArgs", cx.get_const_i32(num_args)),
330 (eight, "ArgBasePtrs", geps[0]),
331 (eight, "ArgPtrs", geps[1]),
332 (eight, "ArgSizes", geps[2]),
333 (eight, "ArgTypes", memtransfer_types),
334 (eight, "ArgNames", cx.const_null(cx.type_ptr())), (eight, "ArgMappers", cx.const_null(cx.type_ptr())), (eight, "Tripcount", cx.get_const_i64(KernelArgsTy::TRIPCOUNT)),
338 (eight, "Flags", cx.get_const_i64(KernelArgsTy::FLAGS)),
339 (four, "NumTeams", workgroup_dims),
340 (four, "ThreadLimit", thread_dims),
341 (four, "DynCGroupMem", dyn_cache),
342 ]
343 }
344}
345
346#[derive(#[automatically_derived]
impl<'ll> ::core::marker::Copy for OffloadKernelGlobals<'ll> { }Copy, #[automatically_derived]
impl<'ll> ::core::clone::Clone for OffloadKernelGlobals<'ll> {
#[inline]
fn clone(&self) -> OffloadKernelGlobals<'ll> {
let _: ::core::clone::AssertParamIsClone<&'ll llvm::Value>;
let _: ::core::clone::AssertParamIsClone<&'ll llvm::Value>;
let _: ::core::clone::AssertParamIsClone<&'ll llvm::Value>;
let _: ::core::clone::AssertParamIsClone<&'ll llvm::Value>;
let _: ::core::clone::AssertParamIsClone<&'ll llvm::Value>;
*self
}
}Clone)]
348pub(crate) struct OffloadKernelGlobals<'ll> {
349 pub offload_sizes: &'ll llvm::Value,
350 pub memtransfer_begin: &'ll llvm::Value,
351 pub memtransfer_kernel: &'ll llvm::Value,
352 pub memtransfer_end: &'ll llvm::Value,
353 pub region_id: &'ll llvm::Value,
354}
355
356fn gen_tgt_data_mappers<'ll>(
357 cx: &CodegenCx<'ll, '_>,
358) -> (&'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Type) {
359 let tptr = cx.type_ptr();
360 let ti64 = cx.type_i64();
361 let ti32 = cx.type_i32();
362
363 let args = ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
[tptr, ti64, ti32, tptr, tptr, tptr, tptr, tptr, tptr]))vec![tptr, ti64, ti32, tptr, tptr, tptr, tptr, tptr, tptr];
364 let mapper_fn_ty = cx.type_func(&args, cx.type_void());
365 let mapper_begin = "__tgt_target_data_begin_mapper";
366 let mapper_update = "__tgt_target_data_update_mapper";
367 let mapper_end = "__tgt_target_data_end_mapper";
368 let begin_mapper_decl = declare_offload_fn(&cx, mapper_begin, mapper_fn_ty);
369 let update_mapper_decl = declare_offload_fn(&cx, mapper_update, mapper_fn_ty);
370 let end_mapper_decl = declare_offload_fn(&cx, mapper_end, mapper_fn_ty);
371
372 let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
373 attributes::apply_to_llfn(begin_mapper_decl, Function, &[nounwind]);
374 attributes::apply_to_llfn(update_mapper_decl, Function, &[nounwind]);
375 attributes::apply_to_llfn(end_mapper_decl, Function, &[nounwind]);
376
377 (begin_mapper_decl, update_mapper_decl, end_mapper_decl, mapper_fn_ty)
378}
379
380fn add_priv_unnamed_arr<'ll>(cx: &SimpleCx<'ll>, name: &str, vals: &[u64]) -> &'ll llvm::Value {
381 let ti64 = cx.type_i64();
382 let mut size_val = Vec::with_capacity(vals.len());
383 for &val in vals {
384 size_val.push(cx.get_const_i64(val));
385 }
386 let initializer = cx.const_array(ti64, &size_val);
387 add_unnamed_global(cx, name, initializer, PrivateLinkage)
388}
389
390pub(crate) fn add_unnamed_global<'ll>(
391 cx: &SimpleCx<'ll>,
392 name: &str,
393 initializer: &'ll llvm::Value,
394 l: Linkage,
395) -> &'ll llvm::Value {
396 let llglobal = add_global(cx, name, initializer, l);
397 llvm::LLVMSetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global);
398 llglobal
399}
400
401pub(crate) fn add_global<'ll>(
402 cx: &SimpleCx<'ll>,
403 name: &str,
404 initializer: &'ll llvm::Value,
405 l: Linkage,
406) -> &'ll llvm::Value {
407 let c_name = CString::new(name).unwrap();
408 let llglobal: &'ll llvm::Value = llvm::add_global(cx.llmod, cx.val_ty(initializer), &c_name);
409 llvm::set_global_constant(llglobal, true);
410 llvm::set_linkage(llglobal, l);
411 llvm::set_initializer(llglobal, initializer);
412 llglobal
413}
414
415pub(crate) fn gen_define_handling<'ll>(
419 cx: &CodegenCx<'ll, '_>,
420 metadata: &[OffloadMetadata],
421 symbol: String,
422 offload_globals: &OffloadGlobals<'ll>,
423) -> OffloadKernelGlobals<'ll> {
424 if let Some(entry) = cx.offload_kernel_cache.borrow().get(&symbol) {
425 return *entry;
426 }
427
428 let offload_entry_ty = offload_globals.offload_entry_ty;
429
430 let (sizes, transfer): (Vec<_>, Vec<_>) =
431 metadata.iter().map(|m| (m.payload_size, m.mode)).unzip();
432 let handled_mappings = MappingFlags::TO
438 | MappingFlags::FROM
439 | MappingFlags::TARGET_PARAM
440 | MappingFlags::LITERAL
441 | MappingFlags::IMPLICIT;
442 for arg in &transfer {
443 if true {
if !!arg.contains_unknown_bits() {
::core::panicking::panic("assertion failed: !arg.contains_unknown_bits()")
};
};debug_assert!(!arg.contains_unknown_bits());
444 if true {
if !handled_mappings.contains(*arg) {
::core::panicking::panic("assertion failed: handled_mappings.contains(*arg)")
};
};debug_assert!(handled_mappings.contains(*arg));
445 }
446
447 let valid_begin_mappings = MappingFlags::TO | MappingFlags::LITERAL | MappingFlags::IMPLICIT;
448 let transfer_to: Vec<u64> =
449 transfer.iter().map(|m| m.intersection(valid_begin_mappings).bits()).collect();
450 let transfer_from: Vec<u64> =
451 transfer.iter().map(|m| m.intersection(MappingFlags::FROM).bits()).collect();
452 let valid_kernel_mappings = MappingFlags::LITERAL | MappingFlags::IMPLICIT;
453 let transfer_kernel: Vec<u64> = transfer
455 .iter()
456 .map(|m| (m.intersection(valid_kernel_mappings) | MappingFlags::TARGET_PARAM).bits())
457 .collect();
458
459 let actual_sizes = sizes
460 .iter()
461 .map(|s| match s {
462 OffloadSize::Static(sz) => *sz,
463 _ => 0,
465 })
466 .collect::<Vec<_>>();
467 let offload_sizes =
468 add_priv_unnamed_arr(&cx, &::alloc::__export::must_use({
::alloc::fmt::format(format_args!(".offload_sizes.{0}", symbol))
})format!(".offload_sizes.{symbol}"), &actual_sizes);
469 let memtransfer_begin =
470 add_priv_unnamed_arr(&cx, &::alloc::__export::must_use({
::alloc::fmt::format(format_args!(".offload_maptypes.{0}.begin",
symbol))
})format!(".offload_maptypes.{symbol}.begin"), &transfer_to);
471 let memtransfer_kernel =
472 add_priv_unnamed_arr(&cx, &::alloc::__export::must_use({
::alloc::fmt::format(format_args!(".offload_maptypes.{0}.kernel",
symbol))
})format!(".offload_maptypes.{symbol}.kernel"), &transfer_kernel);
473 let memtransfer_end =
474 add_priv_unnamed_arr(&cx, &::alloc::__export::must_use({
::alloc::fmt::format(format_args!(".offload_maptypes.{0}.end",
symbol))
})format!(".offload_maptypes.{symbol}.end"), &transfer_from);
475
476 let name = ::alloc::__export::must_use({
::alloc::fmt::format(format_args!(".{0}.region_id", symbol))
})format!(".{symbol}.region_id");
480 let initializer = cx.get_const_i8(0);
481 let region_id = add_global(&cx, &name, initializer, WeakAnyLinkage);
482
483 let c_entry_name = CString::new(symbol.clone()).unwrap();
484 let c_val = c_entry_name.as_bytes_with_nul();
485 let offload_entry_name = ::alloc::__export::must_use({
::alloc::fmt::format(format_args!(".offloading.entry_name.{0}",
symbol))
})format!(".offloading.entry_name.{symbol}");
486
487 let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
488 let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
489 llvm::set_alignment(llglobal, Align::ONE);
490 llvm::set_section(llglobal, c".llvm.rodata.offloading");
491
492 let name = ::alloc::__export::must_use({
::alloc::fmt::format(format_args!(".offloading.entry.{0}", symbol))
})format!(".offloading.entry.{symbol}");
493
494 let elems = TgtOffloadEntry::new(&cx, region_id, llglobal);
496
497 let initializer = crate::common::named_struct(offload_entry_ty, &elems);
498 let c_name = CString::new(name).unwrap();
499 let offload_entry = llvm::add_global(cx.llmod, offload_entry_ty, &c_name);
500 llvm::set_global_constant(offload_entry, true);
501 llvm::set_linkage(offload_entry, WeakAnyLinkage);
502 llvm::set_initializer(offload_entry, initializer);
503 llvm::set_alignment(offload_entry, Align::EIGHT);
504 let c_section_name = CString::new("llvm_offload_entries").unwrap();
505 llvm::set_section(offload_entry, &c_section_name);
506
507 cx.add_compiler_used_global(offload_entry);
508
509 let result = OffloadKernelGlobals {
510 offload_sizes,
511 memtransfer_begin,
512 memtransfer_kernel,
513 memtransfer_end,
514 region_id,
515 };
516
517 cx.offload_kernel_cache.borrow_mut().insert(symbol, result);
518
519 result
520}
521
522fn declare_offload_fn<'ll>(
523 cx: &CodegenCx<'ll, '_>,
524 name: &str,
525 ty: &'ll llvm::Type,
526) -> &'ll llvm::Value {
527 crate::declare::declare_simple_fn(
528 cx,
529 name,
530 llvm::CallConv::CCallConv,
531 llvm::UnnamedAddr::No,
532 llvm::Visibility::Default,
533 ty,
534 )
535}
536
537pub(crate) fn scalar_width<'ll>(cx: &'ll SimpleCx<'_>, ty: &'ll Type) -> u64 {
538 match cx.type_kind(ty) {
539 TypeKind::Half
540 | TypeKind::Float
541 | TypeKind::Double
542 | TypeKind::X86_FP80
543 | TypeKind::FP128
544 | TypeKind::PPC_FP128 => cx.float_width(ty) as u64,
545 TypeKind::Integer => cx.int_width(ty),
546 other => ::rustc_middle::util::bug::bug_fmt(format_args!("scalar_width was called on a non scalar type {0:?}",
other))bug!("scalar_width was called on a non scalar type {other:?}"),
547 }
548}
549
550fn get_runtime_size<'ll, 'tcx>(
551 builder: &mut Builder<'_, 'll, 'tcx>,
552 args: &[&'ll Value],
553 index: usize,
554 meta: &OffloadMetadata,
555) -> &'ll Value {
556 match meta.payload_size {
557 OffloadSize::Slice { element_size } => {
558 let length_idx = index + 1;
559 let length = args[length_idx];
560 let length_i64 = builder.intcast(length, builder.cx.type_i64(), false);
561 builder.mul(length_i64, builder.cx.get_const_i64(element_size))
562 }
563 _ => ::rustc_middle::util::bug::bug_fmt(format_args!("unexpected offload size {0:?}",
meta.payload_size))bug!("unexpected offload size {:?}", meta.payload_size),
564 }
565}
566
567pub(crate) fn gen_call_handling<'ll, 'tcx>(
586 builder: &mut Builder<'_, 'll, 'tcx>,
587 offload_data: &OffloadKernelGlobals<'ll>,
588 args: &[&'ll Value],
589 types: &[&Type],
590 metadata: &[OffloadMetadata],
591 offload_globals: &OffloadGlobals<'ll>,
592 offload_dims: &OffloadKernelDims<'ll>,
593 dyn_cache: &'ll Value,
594) {
595 let cx = builder.cx;
596 let OffloadKernelGlobals {
597 offload_sizes,
598 memtransfer_begin,
599 memtransfer_kernel,
600 memtransfer_end,
601 region_id,
602 } = offload_data;
603 let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } =
604 offload_dims;
605
606 let has_dynamic = metadata.iter().any(|m| !#[allow(non_exhaustive_omitted_patterns)] match m.payload_size {
OffloadSize::Static(_) => true,
_ => false,
}matches!(m.payload_size, OffloadSize::Static(_)));
607
608 let tgt_decl = offload_globals.launcher_fn;
609 let tgt_target_kernel_ty = offload_globals.launcher_ty;
610
611 let tgt_kernel_decl = offload_globals.kernel_args_ty;
612 let begin_mapper_decl = offload_globals.begin_mapper;
613 let end_mapper_decl = offload_globals.end_mapper;
614 let fn_ty = offload_globals.mapper_fn_ty;
615
616 let num_args = types.len() as u64;
617 let bb = builder.llbb();
618
619 unsafe {
621 llvm::LLVMRustPositionBuilderPastAllocas(&builder.llbuilder, builder.llfn());
622 }
623
624 let ty = cx.type_array(cx.type_ptr(), num_args);
625 let a1 = builder.direct_alloca(ty, Align::EIGHT, ".offload_baseptrs");
627 let a2 = builder.direct_alloca(ty, Align::EIGHT, ".offload_ptrs");
629 let ty2 = cx.type_array(cx.type_i64(), num_args);
631
632 let a4 = if has_dynamic {
633 let alloc = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes");
634
635 builder.memcpy(
636 alloc,
637 Align::EIGHT,
638 offload_sizes,
639 Align::EIGHT,
640 cx.get_const_i64(8 * args.len() as u64),
641 MemFlags::empty(),
642 None,
643 );
644
645 alloc
646 } else {
647 offload_sizes
648 };
649
650 let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args");
652
653 unsafe {
655 llvm::LLVMPositionBuilderAtEnd(&builder.llbuilder, bb);
656 }
657
658 let mut vals = ::alloc::vec::Vec::new()vec![];
660 let mut geps = ::alloc::vec::Vec::new()vec![];
661 let i32_0 = cx.get_const_i32(0);
662 for &v in args {
663 let ty = cx.val_ty(v);
664 let ty_kind = cx.type_kind(ty);
665 let (base_val, gep_base) = match ty_kind {
666 TypeKind::Pointer => (v, v),
667 TypeKind::Half | TypeKind::Float | TypeKind::Double | TypeKind::Integer => {
668 let num_bits = scalar_width(cx, ty);
670
671 let bb = builder.llbb();
672 unsafe {
673 llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, builder.llfn());
674 }
675 let addr = builder.direct_alloca(cx.type_i64(), Align::EIGHT, "addr");
676 unsafe {
677 llvm::LLVMPositionBuilderAtEnd(builder.llbuilder, bb);
678 }
679
680 let cast = builder.bitcast(v, cx.type_ix(num_bits));
681 let value = builder.zext(cast, cx.type_i64());
682 builder.store(value, addr, Align::EIGHT);
683 (value, addr)
684 }
685 other => ::rustc_middle::util::bug::bug_fmt(format_args!("offload does not support {0:?}",
other))bug!("offload does not support {other:?}"),
686 };
687
688 let gep = builder.inbounds_gep(cx.type_f32(), gep_base, &[i32_0]);
689
690 vals.push(base_val);
691 geps.push(gep);
692 }
693
694 for i in 0..num_args {
695 let idx = cx.get_const_i32(i);
696 let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]);
697 builder.store(vals[i as usize], gep1, Align::EIGHT);
698 let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, idx]);
699 builder.store(geps[i as usize], gep2, Align::EIGHT);
700
701 if !#[allow(non_exhaustive_omitted_patterns)] match metadata[i as
usize].payload_size {
OffloadSize::Static(_) => true,
_ => false,
}matches!(metadata[i as usize].payload_size, OffloadSize::Static(_)) {
702 let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, idx]);
703 let size_val = get_runtime_size(builder, args, i as usize, &metadata[i as usize]);
704 builder.store(size_val, gep3, Align::EIGHT);
705 }
706 }
707
708 fn get_geps<'ll, 'tcx>(
711 builder: &mut Builder<'_, 'll, 'tcx>,
712 ty: &'ll Type,
713 ty2: &'ll Type,
714 a1: &'ll Value,
715 a2: &'ll Value,
716 a4: &'ll Value,
717 is_dynamic: bool,
718 ) -> [&'ll Value; 3] {
719 let cx = builder.cx;
720 let i32_0 = cx.get_const_i32(0);
721
722 let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
723 let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
724 let gep3 = if is_dynamic { builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]) } else { a4 };
725 [gep1, gep2, gep3]
726 }
727
728 fn generate_mapper_call<'ll, 'tcx>(
729 builder: &mut Builder<'_, 'll, 'tcx>,
730 geps: [&'ll Value; 3],
731 o_type: &'ll Value,
732 fn_to_call: &'ll Value,
733 fn_ty: &'ll Type,
734 num_args: u64,
735 s_ident_t: &'ll Value,
736 ) {
737 let cx = builder.cx;
738 let nullptr = cx.const_null(cx.type_ptr());
739 let i64_max = cx.get_const_i64(u64::MAX);
740 let num_args = cx.get_const_i32(num_args);
741 let args =
742 ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
[s_ident_t, i64_max, num_args, geps[0], geps[1], geps[2], o_type,
nullptr, nullptr]))vec![s_ident_t, i64_max, num_args, geps[0], geps[1], geps[2], o_type, nullptr, nullptr];
743 builder.call(fn_ty, None, None, fn_to_call, &args, None, None);
744 }
745
746 let s_ident_t = offload_globals.ident_t_global;
748 let geps = get_geps(builder, ty, ty2, a1, a2, a4, has_dynamic);
749 generate_mapper_call(
750 builder,
751 geps,
752 memtransfer_begin,
753 begin_mapper_decl,
754 fn_ty,
755 num_args,
756 s_ident_t,
757 );
758 let values = KernelArgsTy::new(
759 &cx,
760 num_args,
761 memtransfer_kernel,
762 geps,
763 workgroup_dims,
764 thread_dims,
765 dyn_cache,
766 );
767
768 for (i, value) in values.iter().enumerate() {
771 let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]);
772 let name = std::ffi::CString::new(value.1).unwrap();
773 llvm::set_value_name(ptr, &name.as_bytes());
774
775 builder.store(value.2, ptr, value.0);
776 }
777
778 let args = ::alloc::boxed::box_assume_init_into_vec_unsafe(::alloc::intrinsics::write_box_via_move(::alloc::boxed::Box::new_uninit(),
[s_ident_t, cx.get_const_i64(u64::MAX), num_workgroups,
threads_per_block, region_id, a5]))vec![
779 s_ident_t,
780 cx.get_const_i64(u64::MAX), num_workgroups,
783 threads_per_block,
784 region_id,
785 a5,
786 ];
787 builder.call(tgt_target_kernel_ty, None, None, tgt_decl, &args, None, None);
788 let geps = get_geps(builder, ty, ty2, a1, a2, a4, has_dynamic);
792 generate_mapper_call(
793 builder,
794 geps,
795 memtransfer_end,
796 end_mapper_decl,
797 fn_ty,
798 num_args,
799 s_ident_t,
800 );
801}