use std::ffi::CString;

use llvm::Linkage::*;
use rustc_abi::Align;
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
use rustc_middle::ty::offload_meta::OffloadMetadata;

use crate::builder::SBuilder;
use crate::llvm::AttributePlace::Function;
use crate::llvm::{self, BasicBlock, Linkage, Type, Value};
use crate::{SimpleCx, attributes};

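/// Declares `__tgt_target_kernel`, the offload runtime entry point used to launch a kernel, and
/// marks it `nounwind`. Returns the declaration together with its function type so callers can
/// build the launch call.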
fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
    let tptr = cx.type_ptr();
    let ti64 = cx.type_i64();
    let ti32 = cx.type_i32();
    let args = vec![tptr, ti64, ti32, ti32, tptr, tptr];
    let tgt_fn_ty = cx.type_func(&args, ti32);
    let name = "__tgt_target_kernel";
    let tgt_decl = declare_offload_fn(&cx, name, tgt_fn_ty);
    let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
    attributes::apply_to_llfn(tgt_decl, Function, &[nounwind]);
    (tgt_decl, tgt_fn_ty)
}

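/// Builds the constant `struct.ident_t` value (and the ";unknown;unknown;0;0;;" location string it
/// points to) that is passed as the first, source-location argument to the offload runtime calls
/// emitted below.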
fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
    let unknown_txt = ";unknown;unknown;0;0;;";
    let c_entry_name = CString::new(unknown_txt).unwrap();
    let c_val = c_entry_name.as_bytes_with_nul();
    let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
    let at_zero = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
    llvm::set_alignment(at_zero, Align::ONE);

    let struct_ident_ty = cx.type_named_struct("struct.ident_t");
    let struct_elems = vec![
        cx.get_const_i32(0),
        cx.get_const_i32(2),
        cx.get_const_i32(0),
        cx.get_const_i32(22),
        at_zero,
    ];
    let struct_elems_ty: Vec<_> = struct_elems.iter().map(|&x| cx.val_ty(x)).collect();
    let initializer = crate::common::named_struct(struct_ident_ty, &struct_elems);
    cx.set_struct_body(struct_ident_ty, &struct_elems_ty, false);
    let at_one = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
    llvm::set_alignment(at_one, Align::EIGHT);
    at_one
}

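/// Host-side mirror of the offload runtime's `__tgt_offload_entry` record. `new_decl` creates the
/// corresponding LLVM struct type and `new` produces the field values:
/// reserved (i64), version (i16), kind (i16), flags (i32), address (ptr), symbol name (ptr),
/// size (i64), data (i64), aux_addr (ptr).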
pub(crate) struct TgtOffloadEntry {}

impl TgtOffloadEntry {
    pub(crate) fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
        let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
        let tptr = cx.type_ptr();
        let ti64 = cx.type_i64();
        let ti32 = cx.type_i32();
        let ti16 = cx.type_i16();
        let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
        cx.set_struct_body(offload_entry_ty, &entry_elements, false);
        offload_entry_ty
    }

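    /// Returns the initializer values for one entry: version 1, kind 1, zero flags/size/data, the
    /// kernel's region id as the address and the entry-name global as the symbol name.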
    fn new<'ll>(
        cx: &'ll SimpleCx<'_>,
        region_id: &'ll Value,
        llglobal: &'ll Value,
    ) -> [&'ll Value; 9] {
        let reserved = cx.get_const_i64(0);
        let version = cx.get_const_i16(1);
        let kind = cx.get_const_i16(1);
        let flags = cx.get_const_i32(0);
        let size = cx.get_const_i64(0);
        let data = cx.get_const_i64(0);
        let aux_addr = cx.const_null(cx.type_ptr());
        [reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr]
    }
}

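/// Host-side mirror of the offload runtime's `__tgt_kernel_arguments` struct, which bundles what
/// `__tgt_target_kernel` needs: version, argument count, the base-pointer/pointer/size arrays, the
/// map types, and a number of defaulted fields (trip count, flags, and what appear to be launch
/// bound arrays).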
struct KernelArgsTy {}

impl KernelArgsTy {
    const OFFLOAD_VERSION: u64 = 3;
    const FLAGS: u64 = 0;
    const TRIPCOUNT: u64 = 0;
    fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll Type {
        let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
        let tptr = cx.type_ptr();
        let ti64 = cx.type_i64();
        let ti32 = cx.type_i32();
        let tarr = cx.type_array(ti32, 3);

        let kernel_elements =
            vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];

        cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
        kernel_arguments_ty
    }

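    /// Produces the 13 field values (with their store alignments) written into the `kernel_args`
    /// alloca before the launch. The hard-coded `2097152` and `256` entries are presumably the
    /// default team/thread launch bounds.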
    fn new<'ll>(
        cx: &'ll SimpleCx<'_>,
        num_args: u64,
        memtransfer_types: &'ll Value,
        geps: [&'ll Value; 3],
    ) -> [(Align, &'ll Value); 13] {
        let four = Align::from_bytes(4).expect("4 Byte alignment should work");
        let eight = Align::EIGHT;

        let ti32 = cx.type_i32();
        let ci32_0 = cx.get_const_i32(0);
        [
            (four, cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION)),
            (four, cx.get_const_i32(num_args)),
            (eight, geps[0]),
            (eight, geps[1]),
            (eight, geps[2]),
            (eight, memtransfer_types),
            (eight, cx.const_null(cx.type_ptr())),
            (eight, cx.const_null(cx.type_ptr())),
            (eight, cx.get_const_i64(KernelArgsTy::TRIPCOUNT)),
            (eight, cx.get_const_i64(KernelArgsTy::FLAGS)),
            (four, cx.const_array(ti32, &[cx.get_const_i32(2097152), ci32_0, ci32_0])),
            (four, cx.const_array(ti32, &[cx.get_const_i32(256), ci32_0, ci32_0])),
            (four, cx.get_const_i32(0)),
        ]
    }
}

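/// Globals emitted once per kernel by [`gen_define_handling`] and consumed by
/// [`gen_call_handling`] when building the host-side launch code.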
pub(crate) struct OffloadKernelData<'ll> {
    pub offload_sizes: &'ll llvm::Value,
    pub memtransfer_types: &'ll llvm::Value,
    pub region_id: &'ll llvm::Value,
    pub offload_entry: &'ll llvm::Value,
}

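/// Declares the `__tgt_target_data_{begin,update,end}_mapper` runtime functions (all `nounwind`)
/// and returns them together with their shared function type.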
fn gen_tgt_data_mappers<'ll>(
    cx: &'ll SimpleCx<'_>,
) -> (&'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Type) {
    let tptr = cx.type_ptr();
    let ti64 = cx.type_i64();
    let ti32 = cx.type_i32();

    let args = vec![tptr, ti64, ti32, tptr, tptr, tptr, tptr, tptr, tptr];
    let mapper_fn_ty = cx.type_func(&args, cx.type_void());
    let mapper_begin = "__tgt_target_data_begin_mapper";
    let mapper_update = "__tgt_target_data_update_mapper";
    let mapper_end = "__tgt_target_data_end_mapper";
    let begin_mapper_decl = declare_offload_fn(&cx, mapper_begin, mapper_fn_ty);
    let update_mapper_decl = declare_offload_fn(&cx, mapper_update, mapper_fn_ty);
    let end_mapper_decl = declare_offload_fn(&cx, mapper_end, mapper_fn_ty);

    let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
    attributes::apply_to_llfn(begin_mapper_decl, Function, &[nounwind]);
    attributes::apply_to_llfn(update_mapper_decl, Function, &[nounwind]);
    attributes::apply_to_llfn(end_mapper_decl, Function, &[nounwind]);

    (begin_mapper_decl, update_mapper_decl, end_mapper_decl, mapper_fn_ty)
}

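/// Emits a private, unnamed-address constant global holding `vals` as an `i64` array.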
fn add_priv_unnamed_arr<'ll>(cx: &SimpleCx<'ll>, name: &str, vals: &[u64]) -> &'ll llvm::Value {
    let ti64 = cx.type_i64();
    let mut size_val = Vec::with_capacity(vals.len());
    for &val in vals {
        size_val.push(cx.get_const_i64(val));
    }
    let initializer = cx.const_array(ti64, &size_val);
    add_unnamed_global(cx, name, initializer, PrivateLinkage)
}

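/// Like [`add_global`], but additionally marks the global as `unnamed_addr`.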
pub(crate) fn add_unnamed_global<'ll>(
    cx: &SimpleCx<'ll>,
    name: &str,
    initializer: &'ll llvm::Value,
    l: Linkage,
) -> &'ll llvm::Value {
    let llglobal = add_global(cx, name, initializer, l);
    llvm::LLVMSetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global);
    llglobal
}

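/// Adds a constant global with the given name, linkage and initializer to the module.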
pub(crate) fn add_global<'ll>(
    cx: &SimpleCx<'ll>,
    name: &str,
    initializer: &'ll llvm::Value,
    l: Linkage,
) -> &'ll llvm::Value {
    let c_name = CString::new(name).unwrap();
    let llglobal: &'ll llvm::Value = llvm::add_global(cx.llmod, cx.val_ty(initializer), &c_name);
    llvm::set_global_constant(llglobal, true);
    llvm::set_linkage(llglobal, l);
    llvm::set_initializer(llglobal, initializer);
    llglobal
}

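/// Emits the per-kernel globals that describe one offloadable kernel to the runtime: the
/// `.offload_sizes.*` and `.offload_maptypes.*` arrays for its pointer arguments, the
/// `.<symbol>.region_id` byte, the entry-name string, and the `__tgt_offload_entry` record placed
/// in the `llvm_offload_entries` section.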
pub(crate) fn gen_define_handling<'ll>(
    cx: &SimpleCx<'ll>,
    offload_entry_ty: &'ll llvm::Type,
    metadata: &[OffloadMetadata],
    types: &[&Type],
    symbol: &str,
) -> OffloadKernelData<'ll> {
    let ptr_meta = types.iter().zip(metadata).filter_map(|(&x, meta)| match cx.type_kind(x) {
        rustc_codegen_ssa::common::TypeKind::Pointer => Some(meta),
        _ => None,
    });

    let (ptr_sizes, ptr_transfer): (Vec<_>, Vec<_>) =
        ptr_meta.map(|m| (m.payload_size, m.mode.bits() | 0x20)).unzip();

    let offload_sizes = add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{symbol}"), &ptr_sizes);
    let memtransfer_types =
        add_priv_unnamed_arr(&cx, &format!(".offload_maptypes.{symbol}"), &ptr_transfer);

    let name = format!(".{symbol}.region_id");
    let initializer = cx.get_const_i8(0);
    let region_id = add_unnamed_global(&cx, &name, initializer, WeakAnyLinkage);

    let c_entry_name = CString::new(symbol).unwrap();
    let c_val = c_entry_name.as_bytes_with_nul();
    let offload_entry_name = format!(".offloading.entry_name.{symbol}");

    let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
    let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
    llvm::set_alignment(llglobal, Align::ONE);
    llvm::set_section(llglobal, c".llvm.rodata.offloading");

    let name = format!(".offloading.entry.{symbol}");

    let elems = TgtOffloadEntry::new(&cx, region_id, llglobal);

    let initializer = crate::common::named_struct(offload_entry_ty, &elems);
    let c_name = CString::new(name).unwrap();
    let offload_entry = llvm::add_global(cx.llmod, offload_entry_ty, &c_name);
    llvm::set_global_constant(offload_entry, true);
    llvm::set_linkage(offload_entry, WeakAnyLinkage);
    llvm::set_initializer(offload_entry, initializer);
    llvm::set_alignment(offload_entry, Align::EIGHT);
    let c_section_name = CString::new("llvm_offload_entries").unwrap();
    llvm::set_section(offload_entry, &c_section_name);

    OffloadKernelData { offload_sizes, memtransfer_types, region_id, offload_entry }
}

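/// Declares a C-ABI function with default visibility, as used for the offload runtime's entry
/// points.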
fn declare_offload_fn<'ll>(
    cx: &'ll SimpleCx<'_>,
    name: &str,
    ty: &'ll llvm::Type,
) -> &'ll llvm::Value {
    crate::declare::declare_simple_fn(
        cx,
        name,
        llvm::CallConv::CCallConv,
        llvm::UnnamedAddr::No,
        llvm::Visibility::Default,
        ty,
    )
}

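/// Generates the host-side code that launches one kernel: it registers the (empty) binary
/// descriptor and initializes the RTLs, fills the `.offload_baseptrs`, `.offload_ptrs` and
/// `.offload_sizes` arrays, maps the data to the device via the begin mapper, assembles the
/// `__tgt_kernel_arguments` struct, calls `__tgt_target_kernel`, and finally runs the end mapper
/// and unregisters the library again.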
pub(crate) fn gen_call_handling<'ll>(
    cx: &SimpleCx<'ll>,
    bb: &BasicBlock,
    offload_data: &OffloadKernelData<'ll>,
    args: &[&'ll Value],
    types: &[&Type],
    metadata: &[OffloadMetadata],
) {
    let OffloadKernelData { offload_sizes, offload_entry, memtransfer_types, region_id } =
        offload_data;
    let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx);
    let tptr = cx.type_ptr();
    let ti32 = cx.type_i32();
    let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr];
    let tgt_bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
    cx.set_struct_body(tgt_bin_desc, &tgt_bin_desc_ty, false);

    let tgt_kernel_decl = KernelArgsTy::new_decl(&cx);
    let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);

    let mut builder = SBuilder::build(cx, bb);

    let num_args = types.len() as u64;
    let ip = unsafe { llvm::LLVMRustGetInsertPoint(&builder.llbuilder) };

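    // Emit volatile dummy loads of the per-kernel globals; this is presumably done so that the
    // otherwise unused globals survive into the final module.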
    for val in [offload_sizes, offload_entry] {
        unsafe {
            let dummy = llvm::LLVMBuildLoad2(
                &builder.llbuilder,
                llvm::LLVMTypeOf(val),
                val,
                b"dummy\0".as_ptr() as *const _,
            );
            llvm::LLVMSetVolatile(dummy, llvm::TRUE);
        }
    }

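    // The allocas below must live in the function's entry block, so temporarily move the builder
    // past the existing allocas; the original insert point is restored afterwards.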
    let llfn = unsafe { llvm::LLVMGetBasicBlockParent(bb) };
    unsafe {
        llvm::LLVMRustPositionBuilderPastAllocas(&builder.llbuilder, llfn);
    }
    let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc");

    let ty = cx.type_array(cx.type_ptr(), num_args);
    let a1 = builder.direct_alloca(ty, Align::EIGHT, ".offload_baseptrs");
    let a2 = builder.direct_alloca(ty, Align::EIGHT, ".offload_ptrs");
    let ty2 = cx.type_array(cx.type_i64(), num_args);
    let a4 = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes");

    let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args");

    unsafe {
        llvm::LLVMRustRestoreInsertPoint(&builder.llbuilder, ip);
    }
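    // Zero-initialize the 32-byte `__tgt_bin_desc` (an i32 plus three pointers, with padding).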
    builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);

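    // Collect each kernel argument and a GEP to its first element. Note that the GEP element type
    // is hard-coded to f32 here, which presumably only matches the argument types this lowering
    // currently supports.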
    let mut vals = vec![];
    let mut geps = vec![];
    let i32_0 = cx.get_const_i32(0);
    for &v in args {
        let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]);
        vals.push(v);
        geps.push(gep);
    }

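    // Declare the runtime functions for registering the binary descriptor and initializing the
    // runtime libraries.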
    let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void());
    let register_lib_decl = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty);
    let unregister_lib_decl = declare_offload_fn(&cx, "__tgt_unregister_lib", mapper_fn_ty);
    let init_ty = cx.type_func(&[], cx.type_void());
    let init_rtls_decl = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);

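    // Register the (zeroed) binary descriptor and initialize all runtime libraries.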
    builder.call(mapper_fn_ty, register_lib_decl, &[tgt_bin_desc_alloca], None);
    builder.call(init_ty, init_rtls_decl, &[], None);

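    // Fill one slot per argument in the .offload_baseptrs, .offload_ptrs and .offload_sizes
    // arrays.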
    for i in 0..num_args {
        let idx = cx.get_const_i32(i);
        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]);
        builder.store(vals[i as usize], gep1, Align::EIGHT);
        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, idx]);
        builder.store(geps[i as usize], gep2, Align::EIGHT);
        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, idx]);
        builder.store(cx.get_const_i64(metadata[i as usize].payload_size), gep3, Align::EIGHT);
    }

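    // Returns GEPs to the first element of the baseptrs, ptrs and sizes arrays, as expected by
    // the mapper calls and the kernel-argument struct.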
    fn get_geps<'a, 'll>(
        builder: &mut SBuilder<'a, 'll>,
        cx: &'ll SimpleCx<'ll>,
        ty: &'ll Type,
        ty2: &'ll Type,
        a1: &'ll Value,
        a2: &'ll Value,
        a4: &'ll Value,
    ) -> [&'ll Value; 3] {
        let i32_0 = cx.get_const_i32(0);

        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
        [gep1, gep2, gep3]
    }

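    // Emits a call to a begin/end mapper function with the standard argument list: source
    // location, device id (`u64::MAX`, i.e. -1, presumably the default device), argument count,
    // the three offload arrays, the map types, and null arg-names/mapper pointers.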
    fn generate_mapper_call<'a, 'll>(
        builder: &mut SBuilder<'a, 'll>,
        cx: &'ll SimpleCx<'ll>,
        geps: [&'ll Value; 3],
        o_type: &'ll Value,
        fn_to_call: &'ll Value,
        fn_ty: &'ll Type,
        num_args: u64,
        s_ident_t: &'ll Value,
    ) {
        let nullptr = cx.const_null(cx.type_ptr());
        let i64_max = cx.get_const_i64(u64::MAX);
        let num_args = cx.get_const_i32(num_args);
        let args =
            vec![s_ident_t, i64_max, num_args, geps[0], geps[1], geps[2], o_type, nullptr, nullptr];
        builder.call(fn_ty, fn_to_call, &args, None);
    }

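    // Map the argument data to the device, then assemble the kernel-argument struct field by
    // field in the kernel_args alloca.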
    let s_ident_t = generate_at_one(&cx);
    let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
    generate_mapper_call(
        &mut builder,
        &cx,
        geps,
        memtransfer_types,
        begin_mapper_decl,
        fn_ty,
        num_args,
        s_ident_t,
    );
    let values = KernelArgsTy::new(&cx, num_args, memtransfer_types, geps);

    for (i, value) in values.iter().enumerate() {
        let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]);
        builder.store(value.1, ptr, value.0);
    }

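    // Launch the kernel through `__tgt_target_kernel`. `u64::MAX` again presumably selects the
    // default device; `2097152` and `256` match the launch-bound defaults stored in the
    // kernel-argument struct above.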
    let args = vec![
        s_ident_t,
        cx.get_const_i64(u64::MAX),
        cx.get_const_i32(2097152),
        cx.get_const_i32(256),
        region_id,
        a5,
    ];
    builder.call(tgt_target_kernel_ty, tgt_decl, &args, None);
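    // Run the end mapper to transfer results back (as requested by the map types) and release the
    // device mappings, then unregister the library.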
    let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
    generate_mapper_call(
        &mut builder,
        &cx,
        geps,
        memtransfer_types,
        end_mapper_decl,
        fn_ty,
        num_args,
        s_ident_t,
    );

    builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);

    drop(builder);
}