1use std::ffi::CString;
2
3use llvm::Linkage::*;
4use rustc_abi::Align;
5use rustc_codegen_ssa::back::write::CodegenContext;
6use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;
7
8use crate::builder::SBuilder;
9use crate::common::AsCCharPtr;
10use crate::llvm::AttributePlace::Function;
11use crate::llvm::{self, Linkage, Type, Value};
12use crate::{LlvmCodegenBackend, SimpleCx, attributes};
13
14pub(crate) fn handle_gpu_code<'ll>(
15 _cgcx: &CodegenContext<LlvmCodegenBackend>,
16 cx: &'ll SimpleCx<'_>,
17) {
18 let mut memtransfer_types = vec![];
20 let mut region_ids = vec![];
21 let offload_entry_ty = TgtOffloadEntry::new_decl(&cx);
22 for num in 0..9 {
26 let kernel = cx.get_function(&format!("kernel_{num}"));
27 if let Some(kernel) = kernel {
28 let (o, k) = gen_define_handling(&cx, kernel, offload_entry_ty, num);
29 memtransfer_types.push(o);
30 region_ids.push(k);
31 }
32 }
33
34 gen_call_handling(&cx, &memtransfer_types, ®ion_ids);
35}
36
37fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
40 let tptr = cx.type_ptr();
41 let ti64 = cx.type_i64();
42 let ti32 = cx.type_i32();
43 let args = vec![tptr, ti64, ti32, ti32, tptr, tptr];
44 let tgt_fn_ty = cx.type_func(&args, ti32);
45 let name = "__tgt_target_kernel";
46 let tgt_decl = declare_offload_fn(&cx, name, tgt_fn_ty);
47 let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
48 attributes::apply_to_llfn(tgt_decl, Function, &[nounwind]);
49 (tgt_decl, tgt_fn_ty)
50}
51
52fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
58 let unknown_txt = ";unknown;unknown;0;0;;";
59 let c_entry_name = CString::new(unknown_txt).unwrap();
60 let c_val = c_entry_name.as_bytes_with_nul();
61 let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
62 let at_zero = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
63 llvm::set_alignment(at_zero, Align::ONE);
64
65 let struct_ident_ty = cx.type_named_struct("struct.ident_t");
67 let struct_elems = vec![
68 cx.get_const_i32(0),
69 cx.get_const_i32(2),
70 cx.get_const_i32(0),
71 cx.get_const_i32(22),
72 at_zero,
73 ];
74 let struct_elems_ty: Vec<_> = struct_elems.iter().map(|&x| cx.val_ty(x)).collect();
75 let initializer = crate::common::named_struct(struct_ident_ty, &struct_elems);
76 cx.set_struct_body(struct_ident_ty, &struct_elems_ty, false);
77 let at_one = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
78 llvm::set_alignment(at_one, Align::EIGHT);
79 at_one
80}
81
82struct TgtOffloadEntry {
83 }
93
94impl TgtOffloadEntry {
95 pub(crate) fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
96 let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
97 let tptr = cx.type_ptr();
98 let ti64 = cx.type_i64();
99 let ti32 = cx.type_i32();
100 let ti16 = cx.type_i16();
101 let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
104 cx.set_struct_body(offload_entry_ty, &entry_elements, false);
105 offload_entry_ty
106 }
107
108 fn new<'ll>(
109 cx: &'ll SimpleCx<'_>,
110 region_id: &'ll Value,
111 llglobal: &'ll Value,
112 ) -> [&'ll Value; 9] {
113 let reserved = cx.get_const_i64(0);
114 let version = cx.get_const_i16(1);
115 let kind = cx.get_const_i16(1);
116 let flags = cx.get_const_i32(0);
117 let size = cx.get_const_i64(0);
118 let data = cx.get_const_i64(0);
119 let aux_addr = cx.const_null(cx.type_ptr());
120 [reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr]
121 }
122}
123
124struct KernelArgsTy {
126 }
148
149impl KernelArgsTy {
150 const OFFLOAD_VERSION: u64 = 3;
151 const FLAGS: u64 = 0;
152 const TRIPCOUNT: u64 = 0;
153 fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll Type {
154 let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
155 let tptr = cx.type_ptr();
156 let ti64 = cx.type_i64();
157 let ti32 = cx.type_i32();
158 let tarr = cx.type_array(ti32, 3);
159
160 let kernel_elements =
161 vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];
162
163 cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
164 kernel_arguments_ty
165 }
166
167 fn new<'ll>(
168 cx: &'ll SimpleCx<'_>,
169 num_args: u64,
170 memtransfer_types: &[&'ll Value],
171 geps: [&'ll Value; 3],
172 ) -> [(Align, &'ll Value); 13] {
173 let four = Align::from_bytes(4).expect("4 Byte alignment should work");
174 let eight = Align::EIGHT;
175
176 let ti32 = cx.type_i32();
177 let ci32_0 = cx.get_const_i32(0);
178 [
179 (four, cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION)),
180 (four, cx.get_const_i32(num_args)),
181 (eight, geps[0]),
182 (eight, geps[1]),
183 (eight, geps[2]),
184 (eight, memtransfer_types[0]),
185 (eight, cx.const_null(cx.type_ptr())), (eight, cx.const_null(cx.type_ptr())), (eight, cx.get_const_i64(KernelArgsTy::TRIPCOUNT)),
189 (eight, cx.get_const_i64(KernelArgsTy::FLAGS)),
190 (four, cx.const_array(ti32, &[cx.get_const_i32(2097152), ci32_0, ci32_0])),
191 (four, cx.const_array(ti32, &[cx.get_const_i32(256), ci32_0, ci32_0])),
192 (four, cx.get_const_i32(0)),
193 ]
194 }
195}
196
197fn gen_tgt_data_mappers<'ll>(
198 cx: &'ll SimpleCx<'_>,
199) -> (&'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Type) {
200 let tptr = cx.type_ptr();
201 let ti64 = cx.type_i64();
202 let ti32 = cx.type_i32();
203
204 let args = vec![tptr, ti64, ti32, tptr, tptr, tptr, tptr, tptr, tptr];
205 let mapper_fn_ty = cx.type_func(&args, cx.type_void());
206 let mapper_begin = "__tgt_target_data_begin_mapper";
207 let mapper_update = "__tgt_target_data_update_mapper";
208 let mapper_end = "__tgt_target_data_end_mapper";
209 let begin_mapper_decl = declare_offload_fn(&cx, mapper_begin, mapper_fn_ty);
210 let update_mapper_decl = declare_offload_fn(&cx, mapper_update, mapper_fn_ty);
211 let end_mapper_decl = declare_offload_fn(&cx, mapper_end, mapper_fn_ty);
212
213 let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
214 attributes::apply_to_llfn(begin_mapper_decl, Function, &[nounwind]);
215 attributes::apply_to_llfn(update_mapper_decl, Function, &[nounwind]);
216 attributes::apply_to_llfn(end_mapper_decl, Function, &[nounwind]);
217
218 (begin_mapper_decl, update_mapper_decl, end_mapper_decl, mapper_fn_ty)
219}
220
221fn add_priv_unnamed_arr<'ll>(cx: &SimpleCx<'ll>, name: &str, vals: &[u64]) -> &'ll llvm::Value {
222 let ti64 = cx.type_i64();
223 let mut size_val = Vec::with_capacity(vals.len());
224 for &val in vals {
225 size_val.push(cx.get_const_i64(val));
226 }
227 let initializer = cx.const_array(ti64, &size_val);
228 add_unnamed_global(cx, name, initializer, PrivateLinkage)
229}
230
231pub(crate) fn add_unnamed_global<'ll>(
232 cx: &SimpleCx<'ll>,
233 name: &str,
234 initializer: &'ll llvm::Value,
235 l: Linkage,
236) -> &'ll llvm::Value {
237 let llglobal = add_global(cx, name, initializer, l);
238 llvm::LLVMSetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global);
239 llglobal
240}
241
242pub(crate) fn add_global<'ll>(
243 cx: &SimpleCx<'ll>,
244 name: &str,
245 initializer: &'ll llvm::Value,
246 l: Linkage,
247) -> &'ll llvm::Value {
248 let c_name = CString::new(name).unwrap();
249 let llglobal: &'ll llvm::Value = llvm::add_global(cx.llmod, cx.val_ty(initializer), &c_name);
250 llvm::set_global_constant(llglobal, true);
251 llvm::set_linkage(llglobal, l);
252 llvm::set_initializer(llglobal, initializer);
253 llglobal
254}
255
256fn gen_define_handling<'ll>(
260 cx: &'ll SimpleCx<'_>,
261 kernel: &'ll llvm::Value,
262 offload_entry_ty: &'ll llvm::Type,
263 num: i64,
264) -> (&'ll llvm::Value, &'ll llvm::Value) {
265 let types = cx.func_params_types(cx.get_type_of_global(kernel));
266 let num_ptr_types = types
269 .iter()
270 .filter(|&x| matches!(cx.type_kind(x), rustc_codegen_ssa::common::TypeKind::Pointer))
271 .count();
272
273 add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{num}"), &vec![1024; num_ptr_types]);
278 let memtransfer_types = add_priv_unnamed_arr(
284 &cx,
285 &format!(".offload_maptypes.{num}"),
286 &vec![1 + 2 + 32; num_ptr_types],
287 );
288 let name = format!(".kernel_{num}.region_id");
292 let initializer = cx.get_const_i8(0);
293 let region_id = add_unnamed_global(&cx, &name, initializer, WeakAnyLinkage);
294
295 let c_entry_name = CString::new(format!("kernel_{num}")).unwrap();
296 let c_val = c_entry_name.as_bytes_with_nul();
297 let offload_entry_name = format!(".offloading.entry_name.{num}");
298
299 let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
300 let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
301 llvm::set_alignment(llglobal, Align::ONE);
302 llvm::set_section(llglobal, c".llvm.rodata.offloading");
303 let name = format!(".offloading.entry.kernel_{num}");
304
305 let elems = TgtOffloadEntry::new(&cx, region_id, llglobal);
307
308 let initializer = crate::common::named_struct(offload_entry_ty, &elems);
309 let c_name = CString::new(name).unwrap();
310 let llglobal = llvm::add_global(cx.llmod, offload_entry_ty, &c_name);
311 llvm::set_global_constant(llglobal, true);
312 llvm::set_linkage(llglobal, WeakAnyLinkage);
313 llvm::set_initializer(llglobal, initializer);
314 llvm::set_alignment(llglobal, Align::EIGHT);
315 let c_section_name = CString::new("llvm_offload_entries").unwrap();
316 llvm::set_section(llglobal, &c_section_name);
317 (memtransfer_types, region_id)
318}
319
320pub(crate) fn declare_offload_fn<'ll>(
321 cx: &'ll SimpleCx<'_>,
322 name: &str,
323 ty: &'ll llvm::Type,
324) -> &'ll llvm::Value {
325 crate::declare::declare_simple_fn(
326 cx,
327 name,
328 llvm::CallConv::CCallConv,
329 llvm::UnnamedAddr::No,
330 llvm::Visibility::Default,
331 ty,
332 )
333}
334
335fn gen_call_handling<'ll>(
356 cx: &'ll SimpleCx<'_>,
357 memtransfer_types: &[&'ll llvm::Value],
358 region_ids: &[&'ll llvm::Value],
359) {
360 let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx);
361 let tptr = cx.type_ptr();
363 let ti32 = cx.type_i32();
364 let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr];
365 let tgt_bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
366 cx.set_struct_body(tgt_bin_desc, &tgt_bin_desc_ty, false);
367
368 let tgt_kernel_decl = KernelArgsTy::new_decl(&cx);
369 let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);
370
371 let main_fn = cx.get_function("main");
372 let Some(main_fn) = main_fn else { return };
373 let kernel_name = "kernel_1";
374 let call = unsafe {
375 llvm::LLVMRustGetFunctionCall(main_fn, kernel_name.as_c_char_ptr(), kernel_name.len())
376 };
377 let Some(kernel_call) = call else {
378 return;
379 };
380 let kernel_call_bb = unsafe { llvm::LLVMGetInstructionParent(kernel_call) };
381 let called = unsafe { llvm::LLVMGetCalledValue(kernel_call).unwrap() };
382 let mut builder = SBuilder::build(cx, kernel_call_bb);
383
384 let types = cx.func_params_types(cx.get_type_of_global(called));
385 let num_args = types.len() as u64;
386
387 unsafe { llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn) };
391
392 let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc");
393
394 let ty = cx.type_array(cx.type_ptr(), num_args);
395 let a1 = builder.direct_alloca(ty, Align::EIGHT, ".offload_baseptrs");
397 let a2 = builder.direct_alloca(ty, Align::EIGHT, ".offload_ptrs");
399 let ty2 = cx.type_array(cx.type_i64(), num_args);
401 let a4 = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes");
402
403 let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args");
405
406 unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
408 builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);
409
410 let mut vals = vec![];
412 let mut geps = vec![];
413 let i32_0 = cx.get_const_i32(0);
414 for index in 0..types.len() {
415 let v = unsafe { llvm::LLVMGetOperand(kernel_call, index as u32).unwrap() };
416 let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]);
417 vals.push(v);
418 geps.push(gep);
419 }
420
421 let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void());
422 let register_lib_decl = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty);
423 let unregister_lib_decl = declare_offload_fn(&cx, "__tgt_unregister_lib", mapper_fn_ty);
424 let init_ty = cx.type_func(&[], cx.type_void());
425 let init_rtls_decl = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);
426
427 builder.call(mapper_fn_ty, register_lib_decl, &[tgt_bin_desc_alloca], None);
430 builder.call(init_ty, init_rtls_decl, &[], None);
432
433 for i in 0..num_args {
434 let idx = cx.get_const_i32(i);
435 let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]);
436 builder.store(vals[i as usize], gep1, Align::EIGHT);
437 let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, idx]);
438 builder.store(geps[i as usize], gep2, Align::EIGHT);
439 let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, idx]);
440 builder.store(cx.get_const_i64(1024), gep3, Align::EIGHT);
444 }
445
446 fn get_geps<'a, 'll>(
449 builder: &mut SBuilder<'a, 'll>,
450 cx: &'ll SimpleCx<'ll>,
451 ty: &'ll Type,
452 ty2: &'ll Type,
453 a1: &'ll Value,
454 a2: &'ll Value,
455 a4: &'ll Value,
456 ) -> [&'ll Value; 3] {
457 let i32_0 = cx.get_const_i32(0);
458
459 let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
460 let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
461 let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
462 [gep1, gep2, gep3]
463 }
464
465 fn generate_mapper_call<'a, 'll>(
466 builder: &mut SBuilder<'a, 'll>,
467 cx: &'ll SimpleCx<'ll>,
468 geps: [&'ll Value; 3],
469 o_type: &'ll Value,
470 fn_to_call: &'ll Value,
471 fn_ty: &'ll Type,
472 num_args: u64,
473 s_ident_t: &'ll Value,
474 ) {
475 let nullptr = cx.const_null(cx.type_ptr());
476 let i64_max = cx.get_const_i64(u64::MAX);
477 let num_args = cx.get_const_i32(num_args);
478 let args =
479 vec![s_ident_t, i64_max, num_args, geps[0], geps[1], geps[2], o_type, nullptr, nullptr];
480 builder.call(fn_ty, fn_to_call, &args, None);
481 }
482
483 let s_ident_t = generate_at_one(&cx);
485 let o = memtransfer_types[0];
486 let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
487 generate_mapper_call(&mut builder, &cx, geps, o, begin_mapper_decl, fn_ty, num_args, s_ident_t);
488 let values = KernelArgsTy::new(&cx, num_args, memtransfer_types, geps);
489
490 for (i, value) in values.iter().enumerate() {
493 let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]);
494 builder.store(value.1, ptr, value.0);
495 }
496
497 let args = vec![
498 s_ident_t,
499 cx.get_const_i64(u64::MAX), cx.get_const_i32(2097152),
503 cx.get_const_i32(256),
504 region_ids[0],
505 a5,
506 ];
507 let offload_success = builder.call(tgt_target_kernel_ty, tgt_decl, &args, None);
508 unsafe {
510 let next = llvm::LLVMGetNextInstruction(offload_success).unwrap();
511 llvm::LLVMRustPositionAfter(builder.llbuilder, next);
512 llvm::LLVMInstructionEraseFromParent(next);
513 }
514
515 let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
517 generate_mapper_call(&mut builder, &cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t);
518
519 builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);
520
521 drop(builder);
522 unsafe { llvm::LLVMDeleteFunction(called) };
526}