miri/shims/x86/
aesni.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
use rustc_middle::ty::Ty;
use rustc_middle::ty::layout::LayoutOf as _;
use rustc_span::Symbol;
use rustc_target::callconv::{Conv, FnAbi};

use crate::*;

impl<'tcx> EvalContextExt<'tcx> for crate::MiriInterpCx<'tcx> {}
pub(super) trait EvalContextExt<'tcx>: crate::MiriInterpCxExt<'tcx> {
    fn emulate_x86_aesni_intrinsic(
        &mut self,
        link_name: Symbol,
        abi: &FnAbi<'tcx, Ty<'tcx>>,
        args: &[OpTy<'tcx>],
        dest: &MPlaceTy<'tcx>,
    ) -> InterpResult<'tcx, EmulateItemResult> {
        let this = self.eval_context_mut();
        this.expect_target_feature_for_intrinsic(link_name, "aes")?;
        // Prefix should have already been checked.
        let unprefixed_name = link_name.as_str().strip_prefix("llvm.x86.aesni.").unwrap();

        match unprefixed_name {
            // Used to implement the _mm_aesdec_si128, _mm256_aesdec_epi128
            // and _mm512_aesdec_epi128 functions.
            // Performs one round of an AES decryption on each 128-bit word of
            // `state` with the corresponding 128-bit key of `key`.
            // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_aesdec_si128
            "aesdec" | "aesdec.256" | "aesdec.512" => {
                let [state, key] = this.check_shim(abi, Conv::C, link_name, args)?;
                aes_round(this, state, key, dest, |state, key| {
                    let key = aes::Block::from(key.to_le_bytes());
                    let mut state = aes::Block::from(state.to_le_bytes());
                    // `aes::hazmat::equiv_inv_cipher_round` documentation states that
                    // it performs the same operation as the x86 aesdec instruction.
                    aes::hazmat::equiv_inv_cipher_round(&mut state, &key);
                    u128::from_le_bytes(state.into())
                })?;
            }
            // Used to implement the _mm_aesdeclast_si128, _mm256_aesdeclast_epi128
            // and _mm512_aesdeclast_epi128 functions.
            // Performs last round of an AES decryption on each 128-bit word of
            // `state` with the corresponding 128-bit key of `key`.
            // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_aesdeclast_si128
            "aesdeclast" | "aesdeclast.256" | "aesdeclast.512" => {
                let [state, key] = this.check_shim(abi, Conv::C, link_name, args)?;

                aes_round(this, state, key, dest, |state, key| {
                    let mut state = aes::Block::from(state.to_le_bytes());
                    // `aes::hazmat::equiv_inv_cipher_round` does the following operations:
                    // state = InvShiftRows(state)
                    // state = InvSubBytes(state)
                    // state = InvMixColumns(state)
                    // state = state ^ key
                    // But we need to skip the InvMixColumns.
                    // First, use a zeroed key to skip the XOR.
                    aes::hazmat::equiv_inv_cipher_round(&mut state, &aes::Block::from([0; 16]));
                    // Then, undo the InvMixColumns with MixColumns.
                    aes::hazmat::mix_columns(&mut state);
                    // Finally, do the XOR.
                    u128::from_le_bytes(state.into()) ^ key
                })?;
            }
            // Used to implement the _mm_aesenc_si128, _mm256_aesenc_epi128
            // and _mm512_aesenc_epi128 functions.
            // Performs one round of an AES encryption on each 128-bit word of
            // `state` with the corresponding 128-bit key of `key`.
            // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_aesenc_si128
            "aesenc" | "aesenc.256" | "aesenc.512" => {
                let [state, key] = this.check_shim(abi, Conv::C, link_name, args)?;
                aes_round(this, state, key, dest, |state, key| {
                    let key = aes::Block::from(key.to_le_bytes());
                    let mut state = aes::Block::from(state.to_le_bytes());
                    // `aes::hazmat::cipher_round` documentation states that
                    // it performs the same operation as the x86 aesenc instruction.
                    aes::hazmat::cipher_round(&mut state, &key);
                    u128::from_le_bytes(state.into())
                })?;
            }
            // Used to implement the _mm_aesenclast_si128, _mm256_aesenclast_epi128
            // and _mm512_aesenclast_epi128 functions.
            // Performs last round of an AES encryption on each 128-bit word of
            // `state` with the corresponding 128-bit key of `key`.
            // https://www.intel.com/content/www/us/en/docs/intrinsics-guide/index.html#text=_mm_aesenclast_si128
            "aesenclast" | "aesenclast.256" | "aesenclast.512" => {
                let [state, key] = this.check_shim(abi, Conv::C, link_name, args)?;
                aes_round(this, state, key, dest, |state, key| {
                    let mut state = aes::Block::from(state.to_le_bytes());
                    // `aes::hazmat::cipher_round` does the following operations:
                    // state = ShiftRows(state)
                    // state = SubBytes(state)
                    // state = MixColumns(state)
                    // state = state ^ key
                    // But we need to skip the MixColumns.
                    // First, use a zeroed key to skip the XOR.
                    aes::hazmat::cipher_round(&mut state, &aes::Block::from([0; 16]));
                    // Then, undo the MixColumns with InvMixColumns.
                    aes::hazmat::inv_mix_columns(&mut state);
                    // Finally, do the XOR.
                    u128::from_le_bytes(state.into()) ^ key
                })?;
            }
            // Used to implement the _mm_aesimc_si128 function.
            // Performs the AES InvMixColumns operation on `op`
            "aesimc" => {
                let [op] = this.check_shim(abi, Conv::C, link_name, args)?;
                // Transmute to `u128`
                let op = op.transmute(this.machine.layouts.u128, this)?;
                let dest = dest.transmute(this.machine.layouts.u128, this)?;

                let state = this.read_scalar(&op)?.to_u128()?;
                let mut state = aes::Block::from(state.to_le_bytes());
                aes::hazmat::inv_mix_columns(&mut state);

                this.write_scalar(Scalar::from_u128(u128::from_le_bytes(state.into())), &dest)?;
            }
            // TODO: Implement the `llvm.x86.aesni.aeskeygenassist` when possible
            // with an external crate.
            _ => return interp_ok(EmulateItemResult::NotSupported),
        }
        interp_ok(EmulateItemResult::NeedsReturn)
    }
}

/// Applies the single-lane AES round function `f` to every 128-bit lane.
///
/// `state`, `key` and `dest` must all have the same size, and that size must be a
/// multiple of 16 bytes (both conditions are asserted). The operands are viewed
/// as `[u128; N]` arrays. Lane `i` of `dest` receives
/// `f(state[i], key[i])`.
fn aes_round<'tcx>(
    ecx: &mut crate::MiriInterpCx<'tcx>,
    state: &OpTy<'tcx>,
    key: &OpTy<'tcx>,
    dest: &MPlaceTy<'tcx>,
    f: impl Fn(u128, u128) -> u128,
) -> InterpResult<'tcx, ()> {
    let size = dest.layout.size;
    assert_eq!(size, state.layout.size);
    assert_eq!(size, key.layout.size);
    assert_eq!(size.bytes() % 16, 0);
    let lanes = size.bytes() / 16;

    // Reinterpret all three operands as `[u128; lanes]` so lanes can be indexed.
    let lanes_ty = Ty::new_array(ecx.tcx.tcx, ecx.tcx.types.u128, lanes);
    let lanes_layout = ecx.layout_of(lanes_ty)?;
    let state = state.transmute(lanes_layout, ecx)?;
    let key = key.transmute(lanes_layout, ecx)?;
    let dest = dest.transmute(lanes_layout, ecx)?;

    for lane in 0..lanes {
        let state_lane = ecx.project_index(&state, lane)?;
        let state_val = ecx.read_scalar(&state_lane)?.to_u128()?;
        let key_lane = ecx.project_index(&key, lane)?;
        let key_val = ecx.read_scalar(&key_lane)?.to_u128()?;
        let dest_lane = ecx.project_index(&dest, lane)?;

        let out = f(state_val, key_val);
        ecx.write_scalar(Scalar::from_u128(out), &dest_lane)?;
    }

    interp_ok(())
}