// rustc_codegen_llvm/builder/gpu_offload.rs

use std::ffi::CString;

use llvm::Linkage::*;
use rustc_abi::Align;
use rustc_codegen_ssa::back::write::CodegenContext;
use rustc_codegen_ssa::traits::BaseTypeCodegenMethods;

use crate::builder::SBuilder;
use crate::common::AsCCharPtr;
use crate::llvm::AttributePlace::Function;
use crate::llvm::{self, Linkage, Type, Value};
use crate::{LlvmCodegenBackend, SimpleCx, attributes};

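/// Entry point for the offload handling: looks for pre-generated kernels named
/// `kernel_0` through `kernel_8`, emits the offload metadata for each one that
/// exists, and then patches the host code so the kernel is launched through the
/// offload runtime.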
pub(crate) fn handle_gpu_code<'ll>(
    _cgcx: &CodegenContext<LlvmCodegenBackend>,
    cx: &'ll SimpleCx<'_>,
) {
    let mut memtransfer_types = vec![];
    let mut region_ids = vec![];
    let offload_entry_ty = TgtOffloadEntry::new_decl(&cx);
    for num in 0..9 {
        let kernel = cx.get_function(&format!("kernel_{num}"));
        if let Some(kernel) = kernel {
            let (o, k) = gen_define_handling(&cx, kernel, offload_entry_ty, num);
            memtransfer_types.push(o);
            region_ids.push(k);
        }
    }

    gen_call_handling(&cx, &memtransfer_types, &region_ids);
}

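/// Declares the offload runtime's kernel launcher,
/// `i32 __tgt_target_kernel(ptr ident, i64 device_id, i32 num_teams,
/// i32 thread_limit, ptr region_id, ptr kernel_args)`
/// (parameter names follow the libomptarget interface).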
fn generate_launcher<'ll>(cx: &'ll SimpleCx<'_>) -> (&'ll llvm::Value, &'ll llvm::Type) {
    let tptr = cx.type_ptr();
    let ti64 = cx.type_i64();
    let ti32 = cx.type_i32();
    let args = vec![tptr, ti64, ti32, ti32, tptr, tptr];
    let tgt_fn_ty = cx.type_func(&args, ti32);
    let name = "__tgt_target_kernel";
    let tgt_decl = declare_offload_fn(&cx, name, tgt_fn_ty);
    let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
    attributes::apply_to_llfn(tgt_decl, Function, &[nounwind]);
    (tgt_decl, tgt_fn_ty)
}

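/// Builds the constant `ident_t` source-location struct expected by the offload
/// runtime. There is no real source location to report, so the embedded string
/// is the generic ";unknown;unknown;0;0;;" placeholder.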
fn generate_at_one<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Value {
    let unknown_txt = ";unknown;unknown;0;0;;";
    let c_entry_name = CString::new(unknown_txt).unwrap();
    let c_val = c_entry_name.as_bytes_with_nul();
    let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
    let at_zero = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
    llvm::set_alignment(at_zero, Align::ONE);

    let struct_ident_ty = cx.type_named_struct("struct.ident_t");
    let struct_elems = vec![
        cx.get_const_i32(0),
        cx.get_const_i32(2),
        cx.get_const_i32(0),
        cx.get_const_i32(22),
        at_zero,
    ];
    let struct_elems_ty: Vec<_> = struct_elems.iter().map(|&x| cx.val_ty(x)).collect();
    let initializer = crate::common::named_struct(struct_ident_ty, &struct_elems);
    cx.set_struct_body(struct_ident_ty, &struct_elems_ty, false);
    let at_one = add_unnamed_global(&cx, &"", initializer, PrivateLinkage);
    llvm::set_alignment(at_one, Align::EIGHT);
    at_one
}

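/// Helper for building `__tgt_offload_entry` records: `new_decl` declares the
/// LLVM struct type (reserved, version, kind, flags, address, name, size, data,
/// aux_addr) and `new` produces the field values for one kernel entry.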
struct TgtOffloadEntry {}

impl TgtOffloadEntry {
    pub(crate) fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll llvm::Type {
        let offload_entry_ty = cx.type_named_struct("struct.__tgt_offload_entry");
        let tptr = cx.type_ptr();
        let ti64 = cx.type_i64();
        let ti32 = cx.type_i32();
        let ti16 = cx.type_i16();
        let entry_elements = vec![ti64, ti16, ti16, ti32, tptr, tptr, ti64, ti64, tptr];
        cx.set_struct_body(offload_entry_ty, &entry_elements, false);
        offload_entry_ty
    }

    fn new<'ll>(
        cx: &'ll SimpleCx<'_>,
        region_id: &'ll Value,
        llglobal: &'ll Value,
    ) -> [&'ll Value; 9] {
        let reserved = cx.get_const_i64(0);
        let version = cx.get_const_i16(1);
        let kind = cx.get_const_i16(1);
        let flags = cx.get_const_i32(0);
        let size = cx.get_const_i64(0);
        let data = cx.get_const_i64(0);
        let aux_addr = cx.const_null(cx.type_ptr());
        [reserved, version, kind, flags, region_id, llglobal, size, data, aux_addr]
    }
}

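/// Helper for building the `__tgt_kernel_arguments` struct that is passed to
/// `__tgt_target_kernel`: `new_decl` declares the LLVM type and `new` yields the
/// (alignment, value) pairs that get stored into it field by field.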
struct KernelArgsTy {}

impl KernelArgsTy {
    const OFFLOAD_VERSION: u64 = 3;
    const FLAGS: u64 = 0;
    const TRIPCOUNT: u64 = 0;
    fn new_decl<'ll>(cx: &'ll SimpleCx<'_>) -> &'ll Type {
        let kernel_arguments_ty = cx.type_named_struct("struct.__tgt_kernel_arguments");
        let tptr = cx.type_ptr();
        let ti64 = cx.type_i64();
        let ti32 = cx.type_i32();
        let tarr = cx.type_array(ti32, 3);

        let kernel_elements =
            vec![ti32, ti32, tptr, tptr, tptr, tptr, tptr, tptr, ti64, ti64, tarr, tarr, ti32];

        cx.set_struct_body(kernel_arguments_ty, &kernel_elements, false);
        kernel_arguments_ty
    }

    fn new<'ll>(
        cx: &'ll SimpleCx<'_>,
        num_args: u64,
        memtransfer_types: &[&'ll Value],
        geps: [&'ll Value; 3],
    ) -> [(Align, &'ll Value); 13] {
        let four = Align::from_bytes(4).expect("4 Byte alignment should work");
        let eight = Align::EIGHT;

        let ti32 = cx.type_i32();
        let ci32_0 = cx.get_const_i32(0);
        [
            (four, cx.get_const_i32(KernelArgsTy::OFFLOAD_VERSION)),
            (four, cx.get_const_i32(num_args)),
            (eight, geps[0]), // .offload_baseptrs
            (eight, geps[1]), // .offload_ptrs
            (eight, geps[2]), // .offload_sizes
            (eight, memtransfer_types[0]),
            (eight, cx.const_null(cx.type_ptr())),
            (eight, cx.const_null(cx.type_ptr())),
            (eight, cx.get_const_i64(KernelArgsTy::TRIPCOUNT)),
            (eight, cx.get_const_i64(KernelArgsTy::FLAGS)),
            // Number of teams and thread limit, each as an [i32; 3] triple.
            (four, cx.const_array(ti32, &[cx.get_const_i32(2097152), ci32_0, ci32_0])),
            (four, cx.const_array(ti32, &[cx.get_const_i32(256), ci32_0, ci32_0])),
            (four, cx.get_const_i32(0)),
        ]
    }
}

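/// Declares the `__tgt_target_data_{begin,update,end}_mapper` runtime functions,
/// which transfer the mapped arguments to and from the device.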
fn gen_tgt_data_mappers<'ll>(
    cx: &'ll SimpleCx<'_>,
) -> (&'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Value, &'ll llvm::Type) {
    let tptr = cx.type_ptr();
    let ti64 = cx.type_i64();
    let ti32 = cx.type_i32();

    let args = vec![tptr, ti64, ti32, tptr, tptr, tptr, tptr, tptr, tptr];
    let mapper_fn_ty = cx.type_func(&args, cx.type_void());
    let mapper_begin = "__tgt_target_data_begin_mapper";
    let mapper_update = "__tgt_target_data_update_mapper";
    let mapper_end = "__tgt_target_data_end_mapper";
    let begin_mapper_decl = declare_offload_fn(&cx, mapper_begin, mapper_fn_ty);
    let update_mapper_decl = declare_offload_fn(&cx, mapper_update, mapper_fn_ty);
    let end_mapper_decl = declare_offload_fn(&cx, mapper_end, mapper_fn_ty);

    let nounwind = llvm::AttributeKind::NoUnwind.create_attr(cx.llcx);
    attributes::apply_to_llfn(begin_mapper_decl, Function, &[nounwind]);
    attributes::apply_to_llfn(update_mapper_decl, Function, &[nounwind]);
    attributes::apply_to_llfn(end_mapper_decl, Function, &[nounwind]);

    (begin_mapper_decl, update_mapper_decl, end_mapper_decl, mapper_fn_ty)
}

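/// Adds a private, unnamed global constant holding `vals` as an array of i64.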
fn add_priv_unnamed_arr<'ll>(cx: &SimpleCx<'ll>, name: &str, vals: &[u64]) -> &'ll llvm::Value {
    let ti64 = cx.type_i64();
    let mut size_val = Vec::with_capacity(vals.len());
    for &val in vals {
        size_val.push(cx.get_const_i64(val));
    }
    let initializer = cx.const_array(ti64, &size_val);
    add_unnamed_global(cx, name, initializer, PrivateLinkage)
}

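/// Like [`add_global`], but additionally marks the global as `unnamed_addr`.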
pub(crate) fn add_unnamed_global<'ll>(
    cx: &SimpleCx<'ll>,
    name: &str,
    initializer: &'ll llvm::Value,
    l: Linkage,
) -> &'ll llvm::Value {
    let llglobal = add_global(cx, name, initializer, l);
    llvm::LLVMSetUnnamedAddress(llglobal, llvm::UnnamedAddr::Global);
    llglobal
}

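/// Adds a global constant with the given name, initializer, and linkage.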
pub(crate) fn add_global<'ll>(
    cx: &SimpleCx<'ll>,
    name: &str,
    initializer: &'ll llvm::Value,
    l: Linkage,
) -> &'ll llvm::Value {
    let c_name = CString::new(name).unwrap();
    let llglobal: &'ll llvm::Value = llvm::add_global(cx.llmod, cx.val_ty(initializer), &c_name);
    llvm::set_global_constant(llglobal, true);
    llvm::set_linkage(llglobal, l);
    llvm::set_initializer(llglobal, initializer);
    llglobal
}

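/// Emits the per-kernel offload metadata: the `.offload_sizes.*` and
/// `.offload_maptypes.*` arrays, a region-id global whose address identifies the
/// kernel, and the `__tgt_offload_entry` record placed in the
/// `llvm_offload_entries` section. Returns the maptypes global and the region id.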
fn gen_define_handling<'ll>(
    cx: &'ll SimpleCx<'_>,
    kernel: &'ll llvm::Value,
    offload_entry_ty: &'ll llvm::Type,
    num: i64,
) -> (&'ll llvm::Value, &'ll llvm::Value) {
    let types = cx.func_params_types(cx.get_type_of_global(kernel));
    // Only pointer arguments receive mapping entries (sizes and map types).
    let num_ptr_types = types
        .iter()
        .filter(|&x| matches!(cx.type_kind(x), rustc_codegen_ssa::common::TypeKind::Pointer))
        .count();

    // The size of each mapped argument is hardcoded to 1024 bytes for now.
    add_priv_unnamed_arr(&cx, &format!(".offload_sizes.{num}"), &vec![1024; num_ptr_types]);
    // Map type 1 + 2 + 32 copies the argument to and from the device and marks
    // it as a kernel parameter.
    let memtransfer_types = add_priv_unnamed_arr(
        &cx,
        &format!(".offload_maptypes.{num}"),
        &vec![1 + 2 + 32; num_ptr_types],
    );
    // A dummy one-byte global whose address identifies the kernel region.
    let name = format!(".kernel_{num}.region_id");
    let initializer = cx.get_const_i8(0);
    let region_id = add_unnamed_global(&cx, &name, initializer, WeakAnyLinkage);

    let c_entry_name = CString::new(format!("kernel_{num}")).unwrap();
    let c_val = c_entry_name.as_bytes_with_nul();
    let offload_entry_name = format!(".offloading.entry_name.{num}");

    let initializer = crate::common::bytes_in_context(cx.llcx, c_val);
    let llglobal = add_unnamed_global(&cx, &offload_entry_name, initializer, InternalLinkage);
    llvm::set_alignment(llglobal, Align::ONE);
    llvm::set_section(llglobal, c".llvm.rodata.offloading");
    let name = format!(".offloading.entry.kernel_{num}");

    // Assemble the `__tgt_offload_entry` for this kernel and place it in the
    // `llvm_offload_entries` section.
    let elems = TgtOffloadEntry::new(&cx, region_id, llglobal);

    let initializer = crate::common::named_struct(offload_entry_ty, &elems);
    let c_name = CString::new(name).unwrap();
    let llglobal = llvm::add_global(cx.llmod, offload_entry_ty, &c_name);
    llvm::set_global_constant(llglobal, true);
    llvm::set_linkage(llglobal, WeakAnyLinkage);
    llvm::set_initializer(llglobal, initializer);
    llvm::set_alignment(llglobal, Align::EIGHT);
    let c_section_name = CString::new("llvm_offload_entries").unwrap();
    llvm::set_section(llglobal, &c_section_name);
    (memtransfer_types, region_id)
}

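/// Declares an offload runtime function with the C calling convention.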
pub(crate) fn declare_offload_fn<'ll>(
    cx: &'ll SimpleCx<'_>,
    name: &str,
    ty: &'ll llvm::Type,
) -> &'ll llvm::Value {
    crate::declare::declare_simple_fn(
        cx,
        name,
        llvm::CallConv::CCallConv,
        llvm::UnnamedAddr::No,
        llvm::Visibility::Default,
        ty,
    )
}

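/// Rewrites the host-side call to `kernel_1` inside `main` into an offload
/// sequence: allocate and fill the offload argument arrays, register the binary
/// descriptor, map the data to the device, launch the kernel via
/// `__tgt_target_kernel`, map the data back, and remove the original call.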
fn gen_call_handling<'ll>(
    cx: &'ll SimpleCx<'_>,
    memtransfer_types: &[&'ll llvm::Value],
    region_ids: &[&'ll llvm::Value],
) {
    let (tgt_decl, tgt_target_kernel_ty) = generate_launcher(&cx);
    // %struct.__tgt_bin_desc = type { i32, ptr, ptr, ptr }
    let tptr = cx.type_ptr();
    let ti32 = cx.type_i32();
    let tgt_bin_desc_ty = vec![ti32, tptr, tptr, tptr];
    let tgt_bin_desc = cx.type_named_struct("struct.__tgt_bin_desc");
    cx.set_struct_body(tgt_bin_desc, &tgt_bin_desc_ty, false);

    let tgt_kernel_decl = KernelArgsTy::new_decl(&cx);
    let (begin_mapper_decl, _, end_mapper_decl, fn_ty) = gen_tgt_data_mappers(&cx);

    // Only the call to `kernel_1` inside `main` is handled for now.
    let main_fn = cx.get_function("main");
    let Some(main_fn) = main_fn else { return };
    let kernel_name = "kernel_1";
    let call = unsafe {
        llvm::LLVMRustGetFunctionCall(main_fn, kernel_name.as_c_char_ptr(), kernel_name.len())
    };
    let Some(kernel_call) = call else {
        return;
    };
    let kernel_call_bb = unsafe { llvm::LLVMGetInstructionParent(kernel_call) };
    let called = unsafe { llvm::LLVMGetCalledValue(kernel_call).unwrap() };
    let mut builder = SBuilder::build(cx, kernel_call_bb);

    let types = cx.func_params_types(cx.get_type_of_global(called));
    let num_args = types.len() as u64;

    // Create the offload-related allocas near the start of `main`, after the
    // allocas that are already there.
    unsafe { llvm::LLVMRustPositionBuilderPastAllocas(builder.llbuilder, main_fn) };

    let tgt_bin_desc_alloca = builder.direct_alloca(tgt_bin_desc, Align::EIGHT, "EmptyDesc");

    // Per-argument arrays for the base pointers, the pointers, and the sizes.
    let ty = cx.type_array(cx.type_ptr(), num_args);
    let a1 = builder.direct_alloca(ty, Align::EIGHT, ".offload_baseptrs");
    let a2 = builder.direct_alloca(ty, Align::EIGHT, ".offload_ptrs");
    let ty2 = cx.type_array(cx.type_i64(), num_args);
    let a4 = builder.direct_alloca(ty2, Align::EIGHT, ".offload_sizes");

    let a5 = builder.direct_alloca(tgt_kernel_decl, Align::EIGHT, "kernel_args");

    // From here on, instructions are emitted directly in front of the original
    // kernel call. Start by zero-initializing the 32-byte empty binary descriptor.
    unsafe { llvm::LLVMRustPositionBefore(builder.llbuilder, kernel_call) };
    builder.memset(tgt_bin_desc_alloca, cx.get_const_i8(0), cx.get_const_i64(32), Align::EIGHT);

    // Collect the operands of the original kernel call, plus a zero-index GEP
    // to each of them (which yields the same address again).
    let mut vals = vec![];
    let mut geps = vec![];
    let i32_0 = cx.get_const_i32(0);
    for index in 0..types.len() {
        let v = unsafe { llvm::LLVMGetOperand(kernel_call, index as u32).unwrap() };
        let gep = builder.inbounds_gep(cx.type_f32(), v, &[i32_0]);
        vals.push(v);
        geps.push(gep);
    }

    let mapper_fn_ty = cx.type_func(&[cx.type_ptr()], cx.type_void());
    let register_lib_decl = declare_offload_fn(&cx, "__tgt_register_lib", mapper_fn_ty);
    let unregister_lib_decl = declare_offload_fn(&cx, "__tgt_unregister_lib", mapper_fn_ty);
    let init_ty = cx.type_func(&[], cx.type_void());
    let init_rtls_decl = declare_offload_fn(cx, "__tgt_init_all_rtls", init_ty);

    // Register the (empty) binary descriptor and initialize all device runtimes.
    builder.call(mapper_fn_ty, register_lib_decl, &[tgt_bin_desc_alloca], None);
    builder.call(init_ty, init_rtls_decl, &[], None);

    // Fill the offload arrays: base pointer, pointer, and (hardcoded 1024-byte)
    // size for every kernel argument.
    for i in 0..num_args {
        let idx = cx.get_const_i32(i);
        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, idx]);
        builder.store(vals[i as usize], gep1, Align::EIGHT);
        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, idx]);
        builder.store(geps[i as usize], gep2, Align::EIGHT);
        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, idx]);
        builder.store(cx.get_const_i64(1024), gep3, Align::EIGHT);
    }

    // Returns pointers to the first element of the baseptr, ptr, and size arrays.
    fn get_geps<'a, 'll>(
        builder: &mut SBuilder<'a, 'll>,
        cx: &'ll SimpleCx<'ll>,
        ty: &'ll Type,
        ty2: &'ll Type,
        a1: &'ll Value,
        a2: &'ll Value,
        a4: &'ll Value,
    ) -> [&'ll Value; 3] {
        let i32_0 = cx.get_const_i32(0);

        let gep1 = builder.inbounds_gep(ty, a1, &[i32_0, i32_0]);
        let gep2 = builder.inbounds_gep(ty, a2, &[i32_0, i32_0]);
        let gep3 = builder.inbounds_gep(ty2, a4, &[i32_0, i32_0]);
        [gep1, gep2, gep3]
    }

    // Emits a call to one of the `__tgt_target_data_*_mapper` runtime functions,
    // using -1 (all ones) as the device id.
    fn generate_mapper_call<'a, 'll>(
        builder: &mut SBuilder<'a, 'll>,
        cx: &'ll SimpleCx<'ll>,
        geps: [&'ll Value; 3],
        o_type: &'ll Value,
        fn_to_call: &'ll Value,
        fn_ty: &'ll Type,
        num_args: u64,
        s_ident_t: &'ll Value,
    ) {
        let nullptr = cx.const_null(cx.type_ptr());
        let i64_max = cx.get_const_i64(u64::MAX);
        let num_args = cx.get_const_i32(num_args);
        let args =
            vec![s_ident_t, i64_max, num_args, geps[0], geps[1], geps[2], o_type, nullptr, nullptr];
        builder.call(fn_ty, fn_to_call, &args, None);
    }

    let s_ident_t = generate_at_one(&cx);
    let o = memtransfer_types[0];
    // Map the argument data over to the device.
    let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
    generate_mapper_call(&mut builder, &cx, geps, o, begin_mapper_decl, fn_ty, num_args, s_ident_t);
    let values = KernelArgsTy::new(&cx, num_args, memtransfer_types, geps);

    // Store the kernel arguments into the `kernel_args` alloca, field by field.
    for (i, value) in values.iter().enumerate() {
        let ptr = builder.inbounds_gep(tgt_kernel_decl, a5, &[i32_0, cx.get_const_i32(i as u64)]);
        builder.store(value.1, ptr, value.0);
    }

    // Launch the kernel: (ident, device_id, num_teams, thread_limit, region_id, kernel_args).
    let args = vec![
        s_ident_t,
        cx.get_const_i64(u64::MAX),
        cx.get_const_i32(2097152),
        cx.get_const_i32(256),
        region_ids[0],
        a5,
    ];
    let offload_success = builder.call(tgt_target_kernel_ty, tgt_decl, &args, None);
    unsafe {
        // The instruction after the launch is the original host-side kernel call;
        // remove it now that the kernel runs on the device.
        let next = llvm::LLVMGetNextInstruction(offload_success).unwrap();
        llvm::LLVMRustPositionAfter(builder.llbuilder, next);
        llvm::LLVMInstructionEraseFromParent(next);
    }

    // Map the data back from the device and clean up.
    let geps = get_geps(&mut builder, &cx, ty, ty2, a1, a2, a4);
    generate_mapper_call(&mut builder, &cx, geps, o, end_mapper_decl, fn_ty, num_args, s_ident_t);

    builder.call(mapper_fn_ty, unregister_lib_decl, &[tgt_bin_desc_alloca], None);

    drop(builder);
    // Remove the now-unused host-side kernel function.
    unsafe { llvm::LLVMDeleteFunction(called) };
}