rustc_target/callconv/
nvptx64.rs

1use rustc_abi::{HasDataLayout, Reg, Size, TyAbiInterface};
2
3use super::{ArgAttribute, ArgAttributes, ArgExtension, CastTarget};
4use crate::callconv::{ArgAbi, FnAbi, Uniform};
5
6fn classify_ret<Ty>(ret: &mut ArgAbi<'_, Ty>) {
7    if ret.layout.is_aggregate() && ret.layout.is_sized() {
8        classify_aggregate(ret)
9    } else if ret.layout.size.bits() < 32 && ret.layout.is_sized() {
10        ret.extend_integer_width_to(32);
11    }
12}
13
14fn classify_arg<Ty>(arg: &mut ArgAbi<'_, Ty>) {
15    if arg.layout.is_aggregate() && arg.layout.is_sized() {
16        classify_aggregate(arg)
17    } else if arg.layout.size.bits() < 32 && arg.layout.is_sized() {
18        arg.extend_integer_width_to(32);
19    }
20}
21
22/// the pass mode used for aggregates in arg and ret position
23fn classify_aggregate<Ty>(arg: &mut ArgAbi<'_, Ty>) {
24    let align_bytes = arg.layout.align.abi.bytes();
25    let size = arg.layout.size;
26
27    let reg = match align_bytes {
28        1 => Reg::i8(),
29        2 => Reg::i16(),
30        4 => Reg::i32(),
31        8 => Reg::i64(),
32        16 => Reg::i128(),
33        _ => unreachable!("Align is given as power of 2 no larger than 16 bytes"),
34    };
35
36    if align_bytes == size.bytes() {
37        arg.cast_to(CastTarget {
38            prefix: [Some(reg), None, None, None, None, None, None, None],
39            rest: Uniform::new(Reg::i8(), Size::from_bytes(0)),
40            attrs: ArgAttributes {
41                regular: ArgAttribute::default(),
42                arg_ext: ArgExtension::None,
43                pointee_size: Size::ZERO,
44                pointee_align: None,
45            },
46        });
47    } else {
48        arg.cast_to(Uniform::new(reg, size));
49    }
50}
51
52fn classify_arg_kernel<'a, Ty, C>(_cx: &C, arg: &mut ArgAbi<'a, Ty>)
53where
54    Ty: TyAbiInterface<'a, C> + Copy,
55    C: HasDataLayout,
56{
57    match arg.mode {
58        super::PassMode::Ignore | super::PassMode::Direct(_) => return,
59        super::PassMode::Pair(_, _) => {}
60        super::PassMode::Cast { .. } => unreachable!(),
61        super::PassMode::Indirect { .. } => {}
62    }
63
64    // FIXME only allow structs and wide pointers here
65    // panic!(
66    //     "`extern \"ptx-kernel\"` doesn't allow passing types other than primitives and structs"
67    // );
68
69    let align_bytes = arg.layout.align.abi.bytes();
70
71    let unit = match align_bytes {
72        1 => Reg::i8(),
73        2 => Reg::i16(),
74        4 => Reg::i32(),
75        8 => Reg::i64(),
76        16 => Reg::i128(),
77        _ => unreachable!("Align is given as power of 2 no larger than 16 bytes"),
78    };
79    if arg.layout.size.bytes() / align_bytes == 1 {
80        // Make sure we pass the struct as array at the LLVM IR level and not as a single integer.
81        arg.cast_to(CastTarget {
82            prefix: [Some(unit), None, None, None, None, None, None, None],
83            rest: Uniform::new(unit, Size::ZERO),
84            attrs: ArgAttributes::new(),
85        });
86    } else {
87        arg.cast_to(Uniform::new(unit, arg.layout.size));
88    }
89}
90
91pub(crate) fn compute_abi_info<Ty>(fn_abi: &mut FnAbi<'_, Ty>) {
92    if !fn_abi.ret.is_ignore() {
93        classify_ret(&mut fn_abi.ret);
94    }
95
96    for arg in fn_abi.args.iter_mut() {
97        if arg.is_ignore() {
98            continue;
99        }
100        classify_arg(arg);
101    }
102}
103
104pub(crate) fn compute_ptx_kernel_abi_info<'a, Ty, C>(cx: &C, fn_abi: &mut FnAbi<'a, Ty>)
105where
106    Ty: TyAbiInterface<'a, C> + Copy,
107    C: HasDataLayout,
108{
109    if !fn_abi.ret.layout.is_unit() && !fn_abi.ret.layout.is_never() {
110        panic!("Kernels should not return anything other than () or !");
111    }
112
113    for arg in fn_abi.args.iter_mut() {
114        if arg.is_ignore() {
115            continue;
116        }
117        classify_arg_kernel(cx, arg);
118    }
119}