1use std::ffi::CString;
2
3use llvm::Linkage::*;
4use rustc_abi::Align;
5use rustc_codegen_ssa::mir::operand::{OperandRef, OperandValue};
6use rustc_codegen_ssa::traits::{BaseTypeCodegenMethods, BuilderMethods};
7use rustc_middle::bug;
8use rustc_middle::ty::offload_meta::OffloadMetadata;
9
10use crate::builder::Builder;
11use crate::common::CodegenCx;
12use crate::llvm::AttributePlace::Function;
13use crate::llvm::{self, Linkage, Type, Value};
14use crate::{SimpleCx, attributes};
15
16pub(crate) struct OffloadGlobals<'ll> {
18 pub launcher_fn: &'ll llvm::Value,
19 pub launcher_ty: &'ll llvm::Type,
20
21 pub bin_desc: &'ll llvm::Type,
22
23 pub kernel_args_ty: &'ll llvm::Type,
24
25 pub offload_entry_ty: &'ll llvm::Type,
26
27 pub begin_mapper: &'ll llvm::Value,
28 pub end_mapper: &'ll llvm::Value,
29 pub mapper_fn_ty: &'ll llvm::Type,
30
31 pub ident_t_global: &'ll llvm::Value,
32
33 pub register_lib: &'ll llvm::Value,
34 pub unregister_lib: &'ll llvm::Value,
35 pub init_rtls: &'ll llvm::Value,
36}
37
38impl<'ll> OffloadGlobals<'ll> {
39 pub(crate) fn declare(cx: &CodegenCx<'ll, '_>) -> Self {
40 let (launcher_fn, launcher_ty) = generate_launcher(cx);
41 let kernel_args_ty = KernelArgsTy::new_decl(cx);
42 let offload_entry_ty = TgtOffloadEntry::new_decl(cx);
43 let (begin_mapper, _, end_mapper, mapper_fn_ty) = gen_tgt_data_mappers(cx);
44 let ident_t_global = generate_at_one(cx);
45
46 let tptr = cx.type_ptr();
47 let ti32 = cx.type_i32();
48 let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr];
49 let bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
50 cx.set_struct_body(bin_desc, &tgt_bin_desc_ty, false);
51
52 let reg_lib_decl = cx.type_func(&[cx.type_ptr()], cx.type_void());
53 let register_lib = declare_offload_fn(&cx, "__tgt_register_lib", reg_lib_decl);
54 let unregister_lib = declare_offload_fn(&cx, "__tgt_unregister_lib", reg_lib_decl);
55 let init_ty = cx.type_func(&[], cx.type_void());
56 let init_rtls = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);
57
58 OffloadGlobals {
59 launcher_fn,
60 launcher_ty,
61 bin_desc,
62 kernel_args_ty,
63 offload_entry_ty,
64 begin_mapper,
65 end_mapper,
66 mapper_fn_ty,
67 ident_t_global,
68 register_lib,
69 unregister_lib,
70 init_rtls,
71 }
72 }
73}
74
75pub(crate) struct OffloadKernelDims<'ll> {
76 num_workgroups: &'ll Value,
77 threads_per_block: &'ll Value,
78 workgroup_dims: &'ll Value,
79 thread_dims: &'ll Value,
80}
81
82impl<'ll> OffloadKernelDims<'ll> {
83 pub(crate) fn from_operands<'tcx>(
84 builder: &mut Builder<'_, 'll, 'tcx>,
85 workgroup_op: &OperandRef<'tcx, &'ll llvm::Value>,
86 thread_op: &OperandRef<'tcx, &'ll llvm::Value>,
87 ) -> Self {
88 let cx = builder.cx;
89 let arr_ty = cx.type_array(cx.type_i32(), 3);
90 let four = Align::from_bytes(4).unwrap();
91
92 let OperandValue::Ref(place) = workgroup_op.val else {
93 bug!("expected array operand by reference");
94 };
95 let workgroup_val = builder.load(arr_ty, place.llval, four);
96
97 let OperandValue::Ref(place) = thread_op.val else {
98 bug!("expected array operand by reference");
99 };
100 let thread_val = builder.load(arr_ty, place.llval, four);
101
102 fn mul_dim3<'ll, 'tcx>(
103 builder: &mut Builder<'_, 'll, 'tcx>,
104 arr: &'ll Value,
105 ) -> &'ll Value {
106 let x = builder.extract_value(arr, 0);
107 let y = builder.extract_value(arr, 1);
108 let z = builder.extract_value(arr, 2);
109
110 let xy = builder.mul(x, y);
111 builder.mul(xy, z)
112 }
113
114 let num_workgroups = mul_dim3(builder, workgroup_val);
115 let threads_per_block = mul_dim3(builder, thread_val);
116
117 OffloadKernelDims {
118 workgroup_dims: workgroup_val,
119 thread_dims: thread_val,
120 num_workgroups,
121 threads_per_block,
122 }
123 }
124}
125
126fn generate_launcher<'ll>(cx: &CodegenCx<'ll, '_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
129 let tptr = cx.type_ptr();
130 let ti64 = cx.type_i64();
131 let ti32 = cx.type_i32();
132 let args = vec![tptr, ti64, ti32, ti32, tptr, tptr];
133 let tgt_fn_ty = cx.type_func(&args, ti32);
134 let name = "__tgt_target_kernel";
135 let tgt_decl = declare_offload_fn(&cx, name, tgt_fn_ty);
136 let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
137 attributes::apply_to_llfn(tgt_decl, Function, &[nounwind]);
138 (tgt_decl, tgt_fn_ty)
139}
140
141pub(crate) fn generate_at_one<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll llvm::Value {
147 let unknown_txt = ";unknown;unknown;0;0;;";
148 let c_entry_name = CString::new(unknown_txt).unwrap();
149 let c_val = c_entry_name.as_bytes_with_nul();
150 let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
151 let at_zero = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
152 llvm::set_alignment(at_zero, Align::ONE);
153
154 let struct_ident_ty = cx.type_named_struct("struct.ident_t");
156 let struct_elems = vec![
157 cx.get_const_i32(0),
158 cx.get_const_i32(2),
159 cx.get_const_i32(0),
160 cx.get_const_i32(22),
161 at_zero,
162 ];
163 let struct_elems_ty: Vec<_> = struct_elems.iter().map(|&x| cx.val_ty(x)).collect();
164 let initializer = crate::common::named_struct(struct_ident_ty, &struct_elems);
165 cx.set_struct_body(struct_ident_ty, &struct_elems_ty, false);
166 let at_one = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
167 llvm::set_alignment(at_one, Align::EIGHT);
168 at_one
169}
170
171pub(crate) struct TgtOffloadEntry {
172 }
182
183impl TgtOffloadEntry {
184 pub(crate) fn new_decl<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll llvm::Type {
185 let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
186 let tptr = cx.type_ptr();
187 let ti64 = cx.type_i64();
188 let ti32 = cx.type_i32();
189 let ti16 = cx.type_i16();
190 let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
193 cx.set_struct_body(offload_entry_ty, &entry_elements, false);
194 offload_entry_ty
195 }
196
197 fn new<'ll>(
198 cx: &CodegenCx<'ll, '_>,
199 region_id: &'ll Value,
200 llglobal: &'ll Value,
201 ) -> [&'ll Value; 9] {
202 let reserved = cx.get_const_i64(0);
203 let version = cx.get_const_i16(1);
204 let kind = cx.get_const_i16(1);
205 let flags = cx.get_const_i32(0);
206 let size = cx.get_const_i64(0);
207 let data = cx.get_const_i64(0);
208 let aux_addr = cx.const_null(cx.type_ptr());
209 [reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr]
210 }
211}
212
213struct KernelArgsTy {
215 }
237
238impl KernelArgsTy {
239 const OFFLOAD_VERSION: u64 = 3;
240 const FLAGS: u64 = 0;
241 const TRIPCOUNT: u64 = 0;
242 fn new_decl<'ll>(cx: &CodegenCx<'ll, '_>) -> &'ll Type {
243 let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
244 let tptr = cx.type_ptr();
245 let ti64 = cx.type_i64();
246 let ti32 = cx.type_i32();
247 let tarr = cx.type_array(ti32, 3);
248
249 let kernel_elements =
250 vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
251
252 cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
253 kernel_arguments_ty
254 }
255
256 fn new<'ll, 'tcx>(
257 cx: &CodegenCx<'ll, 'tcx>,
258 num_args: u64,
259 memtransfer_types: &'ll Value,
260 geps: [&'ll Value; 3],
261 workgroup_dims: &'ll Value,
262 thread_dims: &'ll Value,
263 ) -> [(Align, &'ll Value); 13] {
264 let four = Align::from_bytes(4).expect("4 Byte alignment should work");
265 let eight = Align::EIGHT;
266
267 [
268 (four, cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION)),
269 (four, cx.get_const_i32(num_args)),
270 (eight, geps[0]),
271 (eight, geps[1]),
272 (eight, geps[2]),
273 (eight, memtransfer_types),
274 (eight, cx.const_null(cx.type_ptr())), (eight, cx.const_null(cx.type_ptr())), (eight, cx.get_const_i64(KernelArgsTy::TRIPCOUNT)),
278 (eight, cx.get_const_i64(KernelArgsTy::FLAGS)),
279 (four, workgroup_dims),
280 (four, thread_dims),
281 (four, cx.get_const_i32(0)),
282 ]
283 }
284}
285
286#[derive(Copy, Clone)]
288pub(crate) struct OffloadKernelGlobals<'ll> {
289 pub offload_sizes: &'ll llvm::Value,
290 pub memtransfer_types: &'ll llvm::Value,
291 pub region_id: &'ll llvm::Value,
292 pub offload_entry: &'ll llvm::Value,
293}
294
295fn gen_tgt_data_mappers<'ll>(
296 cx: &CodegenCx<'ll, '_>,
297) -> (&'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Type) {
298 let tptr = cx.type_ptr();
299 let ti64 = cx.type_i64();
300 let ti32 = cx.type_i32();
301
302 let args = vec![tptr, ti64, ti32, tptr, tptr, tptr, tptr, tptr, tptr];
303 let mapper_fn_ty = cx.type_func(&args, cx.type_void());
304 let mapper_begin = "__tgt_target_data_begin_mapper";
305 let mapper_update = "__tgt_target_data_update_mapper";
306 let mapper_end = "__tgt_target_data_end_mapper";
307 let begin_mapper_decl = declare_offload_fn(&cx, mapper_begin, mapper_fn_ty);
308 let update_mapper_decl = declare_offload_fn(&cx, mapper_update, mapper_fn_ty);
309 let end_mapper_decl = declare_offload_fn(&cx, mapper_end, mapper_fn_ty);
310
311 let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
312 attributes::apply_to_llfn(begin_mapper_decl, Function, &[nounwind]);
313 attributes::apply_to_llfn(update_mapper_decl, Function, &[nounwind]);
314 attributes::apply_to_llfn(end_mapper_decl, Function, &[nounwind]);
315
316 (begin_mapper_decl, update_mapper_decl, end_mapper_decl, mapper_fn_ty)
317}
318
319fn add_priv_unnamed_arr<'ll>(cx: &SimpleCx<'ll>, name: &str, vals: &[u64]) -> &'ll llvm::Value {
320 let ti64 = cx.type_i64();
321 let mut size_val = Vec::with_capacity(vals.len());
322 for &val in vals {
323 size_val.push(cx.get_const_i64(val));
324 }
325 let initializer = cx.const_array(ti64, &size_val);
326 add_unnamed_global(cx, name, initializer, PrivateLinkage)
327}
328
329pub(crate) fn add_unnamed_global<'ll>(
330 cx: &SimpleCx<'ll>,
331 name: &str,
332 initializer: &'ll llvm::Value,
333 l: Linkage,
334) -> &'ll llvm::Value {
335 let llglobal = add_global(cx, name, initializer, l);
336 llvm::LLVMSetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global);
337 llglobal
338}
339
340pub(crate) fn add_global<'ll>(
341 cx: &SimpleCx<'ll>,
342 name: &str,
343 initializer: &'ll llvm::Value,
344 l: Linkage,
345) -> &'ll llvm::Value {
346 let c_name = CString::new(name).unwrap();
347 let llglobal: &'ll llvm::Value = llvm::add_global(cx.llmod, cx.val_ty(initializer), &c_name);
348 llvm::set_global_constant(llglobal, true);
349 llvm::set_linkage(llglobal, l);
350 llvm::set_initializer(llglobal, initializer);
351 llglobal
352}
353
354pub(crate) fn gen_define_handling<'ll>(
358 cx: &CodegenCx<'ll, '_>,
359 metadata: &[OffloadMetadata],
360 types: &[&'ll Type],
361 symbol: String,
362 offload_globals: &OffloadGlobals<'ll>,
363) -> OffloadKernelGlobals<'ll> {
364 if let Some(entry) = cx.offload_kernel_cache.borrow().get(&symbol) {
365 return *entry;
366 }
367
368 let offload_entry_ty = offload_globals.offload_entry_ty;
369
370 let ptr_meta = types.iter().zip(metadata).filter_map(|(&x, meta)| match cx.type_kind(x) {
373 rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta),
374 _ => None,
375 });
376
377 let (ptr_sizes, ptr_transfer): (Vec<_>, Vec<_>) =
379 ptr_meta.map(|m| (m.payload_size, m.mode.bits() | 0x20)).unzip();
380
381 let offload_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &ptr_sizes);
382 let memtransfer_types =
388 add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}"), &ptr_transfer);
389
390 let name = format!(".{symbol}.region_id");
394 let initializer = cx.get_const_i8(0);
395 let region_id = add_global(&cx, &name, initializer, WeakAnyLinkage);
396
397 let c_entry_name = CString::new(symbol.clone()).unwrap();
398 let c_val = c_entry_name.as_bytes_with_nul();
399 let offload_entry_name = format!(".offloading.entry_name.{symbol}");
400
401 let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
402 let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
403 llvm::set_alignment(llglobal, Align::ONE);
404 llvm::set_section(llglobal, c".llvm.rodata.offloading");
405
406 let name = format!(".offloading.entry.{symbol}");
407
408 let elems = TgtOffloadEntry::new(&cx, region_id, llglobal);
410
411 let initializer = crate::common::named_struct(offload_entry_ty, &elems);
412 let c_name = CString::new(name).unwrap();
413 let offload_entry = llvm::add_global(cx.llmod, offload_entry_ty, &c_name);
414 llvm::set_global_constant(offload_entry, true);
415 llvm::set_linkage(offload_entry, WeakAnyLinkage);
416 llvm::set_initializer(offload_entry, initializer);
417 llvm::set_alignment(offload_entry, Align::EIGHT);
418 let c_section_name = CString::new("llvm_offload_entries").unwrap();
419 llvm::set_section(offload_entry, &c_section_name);
420
421 let result =
422 OffloadKernelGlobals { offload_sizes, memtransfer_types, region_id, offload_entry };
423
424 cx.offload_kernel_cache.borrow_mut().insert(symbol, result);
425
426 result
427}
428
429fn declare_offload_fn<'ll>(
430 cx: &CodegenCx<'ll, '_>,
431 name: &str,
432 ty: &'ll llvm::Type,
433) -> &'ll llvm::Value {
434 crate::declare::declare_simple_fn(
435 cx,
436 name,
437 llvm::CallConv::CCallConv,
438 llvm::UnnamedAddr::No,
439 llvm::Visibility::Default,
440 ty,
441 )
442}
443
444pub(crate) fn gen_call_handling<'ll, 'tcx>(
464 builder: &mut Builder<'_, 'll, 'tcx>,
465 offload_data: &OffloadKernelGlobals<'ll>,
466 args: &[&'ll Value],
467 types: &[&Type],
468 metadata: &[OffloadMetadata],
469 offload_globals: &OffloadGlobals<'ll>,
470 offload_dims: &OffloadKernelDims<'ll>,
471) {
472 let cx = builder.cx;
473 let OffloadKernelGlobals { offload_sizes, offload_entry, memtransfer_types, region_id } =
474 offload_data;
475 let OffloadKernelDims { num_workgroups, threads_per_block, workgroup_dims, thread_dims } =
476 offload_dims;
477
478 let tgt_decl = offload_globals.launcher_fn;
479 let tgt_target_kernel_ty = offload_globals.launcher_ty;
480
481 let tgt_bin_desc = offload_globals.bin_desc;
483
484 let tgt_kernel_decl = offload_globals.kernel_args_ty;
485 let begin_mapper_decl = offload_globals.begin_mapper;
486 let end_mapper_decl = offload_globals.end_mapper;
487 let fn_ty = offload_globals.mapper_fn_ty;
488
489 let num_args = types.len() as u64;
490 let bb = builder.llbb();
491
492 for val in [offload_sizes, offload_entry] {
495 unsafe {
496 let dummy = llvm::LLVMBuildLoad2(
497 &builder.llbuilder,
498 llvm::LLVMTypeOf(val),
499 val,
500 b"dummy\0".as_ptr() as *const _,
501 );
502 llvm::LLVMSetVolatile(dummy, llvm::TRUE);
503 }
504 }
505
506 unsafe {
510 llvm::LLVMRustPositionBuilderPastAllocas(&builder.llbuilder, builder.llfn());
511 }
512 let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc");
513
514 let ty = cx.type_array(cx.type_ptr(), num_args);
515 let a1 = builder.direct_alloca(ty, Align::EIGHT, ".offload_baseptrs");
517 let a2 = builder.direct_alloca(ty, Align::EIGHT, ".offload_ptrs");
519 let ty2 = cx.type_array(cx.type_i64(), num_args);
521 let a4 = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes");
522
523 let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args");
525
526 unsafe {
528 llvm::LLVMPositionBuilderAtEnd(&builder.llbuilder, bb);
529 }
530 builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);
531
532 let mut vals = vec![];
534 let mut geps = vec![];
535 let i32_0 = cx.get_const_i32(0);
536 for &v in args {
537 let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]);
538 vals.push(v);
539 geps.push(gep);
540 }
541
542 let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void());
543 let register_lib_decl = offload_globals.register_lib;
544 let unregister_lib_decl = offload_globals.unregister_lib;
545 let init_ty = cx.type_func(&[], cx.type_void());
546 let init_rtls_decl = offload_globals.init_rtls;
547
548 builder.call(mapper_fn_ty, None, None, register_lib_decl, &[tgt_bin_desc_alloca], None, None);
551 builder.call(init_ty, None, None, init_rtls_decl, &[], None, None);
553
554 for i in 0..num_args {
555 let idx = cx.get_const_i32(i);
556 let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]);
557 builder.store(vals[i as usize], gep1, Align::EIGHT);
558 let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, idx]);
559 builder.store(geps[i as usize], gep2, Align::EIGHT);
560 let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, idx]);
561 builder.store(cx.get_const_i64(metadata[i as usize].payload_size), gep3, Align::EIGHT);
563 }
564
565 fn get_geps<'ll, 'tcx>(
568 builder: &mut Builder<'_, 'll, 'tcx>,
569 ty: &'ll Type,
570 ty2: &'ll Type,
571 a1: &'ll Value,
572 a2: &'ll Value,
573 a4: &'ll Value,
574 ) -> [&'ll Value; 3] {
575 let cx = builder.cx;
576 let i32_0 = cx.get_const_i32(0);
577
578 let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
579 let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
580 let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
581 [gep1, gep2, gep3]
582 }
583
584 fn generate_mapper_call<'ll, 'tcx>(
585 builder: &mut Builder<'_, 'll, 'tcx>,
586 geps: [&'ll Value; 3],
587 o_type: &'ll Value,
588 fn_to_call: &'ll Value,
589 fn_ty: &'ll Type,
590 num_args: u64,
591 s_ident_t: &'ll Value,
592 ) {
593 let cx = builder.cx;
594 let nullptr = cx.const_null(cx.type_ptr());
595 let i64_max = cx.get_const_i64(u64::MAX);
596 let num_args = cx.get_const_i32(num_args);
597 let args =
598 vec![s_ident_t, i64_max, num_args, geps[0], geps[1], geps[2], o_type, nullptr, nullptr];
599 builder.call(fn_ty, None, None, fn_to_call, &args, None, None);
600 }
601
602 let s_ident_t = offload_globals.ident_t_global;
604 let geps = get_geps(builder, ty, ty2, a1, a2, a4);
605 generate_mapper_call(
606 builder,
607 geps,
608 memtransfer_types,
609 begin_mapper_decl,
610 fn_ty,
611 num_args,
612 s_ident_t,
613 );
614 let values =
615 KernelArgsTy::new(&cx, num_args, memtransfer_types, geps, workgroup_dims, thread_dims);
616
617 for (i, value) in values.iter().enumerate() {
620 let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]);
621 builder.store(value.1, ptr, value.0);
622 }
623
624 let args = vec![
625 s_ident_t,
626 cx.get_const_i64(u64::MAX), num_workgroups,
629 threads_per_block,
630 region_id,
631 a5,
632 ];
633 builder.call(tgt_target_kernel_ty, None, None, tgt_decl, &args, None, None);
634 let geps = get_geps(builder, ty, ty2, a1, a2, a4);
638 generate_mapper_call(
639 builder,
640 geps,
641 memtransfer_types,
642 end_mapper_decl,
643 fn_ty,
644 num_args,
645 s_ident_t,
646 );
647
648 builder.call(mapper_fn_ty, None, None, unregister_lib_decl, &[tgt_bin_desc_alloca], None, None);
649}