1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
//! This pass finds basic blocks that are completely equal,
//! and replaces all uses with just one of them.

use std::{collections::hash_map::Entry, hash::Hash, hash::Hasher, iter};

use rustc_data_structures::fx::FxHashMap;
use rustc_middle::mir::visit::MutVisitor;
use rustc_middle::mir::*;
use rustc_middle::ty::TyCtxt;

use super::simplify::simplify_cfg;

/// MIR pass that rewrites jumps to duplicate basic blocks so that they all
/// target a single representative block.
pub struct DeduplicateBlocks;

impl<'tcx> MirPass<'tcx> for DeduplicateBlocks {
    /// The pass only runs when the session's MIR optimization level is 4 or higher.
    fn is_enabled(&self, sess: &rustc_session::Session) -> bool {
        sess.mir_opt_level() >= 4
    }

    /// Looks for duplicate blocks in `body`; when any are found, redirects the
    /// terminators that target them and then runs `simplify_cfg` on the body.
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        debug!("Running DeduplicateBlocks on `{:?}`", body.source);

        let duplicates = find_duplicates(body);
        // No duplicates were found, so there is nothing to rewrite.
        if duplicates.is_empty() {
            return;
        }

        let mut applier = OptApplier { tcx, duplicates };
        applier.visit_body(body);
        simplify_cfg(body);
    }
}

/// Visitor state used to redirect terminator targets away from duplicate blocks.
struct OptApplier<'tcx> {
    // Handed back from `MutVisitor::tcx`.
    tcx: TyCtxt<'tcx>,
    // Maps a duplicate block to the block that should be jumped to instead.
    duplicates: FxHashMap<BasicBlock, BasicBlock>,
}

impl<'tcx> MutVisitor<'tcx> for OptApplier<'tcx> {
    fn tcx(&self) -> TyCtxt<'tcx> {
        self.tcx
    }

    /// Rewrites every successor of `terminator` that is recorded as a duplicate so
    /// that it points at its replacement, then continues the default traversal.
    fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) {
        for successor in terminator.successors_mut() {
            // Successors that are not known duplicates are left as they are.
            let Some(&new_target) = self.duplicates.get(successor) else {
                continue;
            };
            debug!("SUCCESS: Replacing: `{:?}` with `{:?}`", successor, new_target);
            *successor = new_target;
        }

        self.super_terminator(terminator, location);
    }
}

/// Builds a map from each duplicate basic block to the block that replaces it.
///
/// Cleanup blocks and blocks with more than 10 statements are never considered.
fn find_duplicates(body: &Body<'_>) -> FxHashMap<BasicBlock, BasicBlock> {
    let mut duplicates = FxHashMap::default();

    // Upper bound on the number of distinct blocks we may need to remember.
    let non_cleanup_blocks =
        body.basic_blocks.iter_enumerated().filter(|(_, data)| !data.is_cleanup).count();

    // Representative block for every distinct block content seen so far.
    let mut representatives =
        FxHashMap::with_capacity_and_hasher(non_cleanup_blocks, Default::default());

    // The blocks are walked from the highest index to the lowest, so the first block
    // recorded for a given content is the one with the highest index, and every
    // lower-indexed duplicate is mapped onto it.
    // For example, if bb1, bb2 and bb3 are duplicates, bb3 is recorded first as the
    // representative. bb2 is then found to be equal to bb3 and is inserted into the
    // duplicates map with replacement bb3, and the same happens for bb1.
    // Once the duplicates are removed, only bb3 is left.
    for (bb, bbd) in body.basic_blocks.iter_enumerated().rev() {
        if bbd.is_cleanup {
            continue;
        }

        // Basic blocks can get really big, and big blocks are unlikely to have
        // duplicates, so they are skipped instead of being hashed and compared. The
        // threshold was found experimentally by eprintln while compiling the crates
        // in the rustc-perf suite.
        if bbd.statements.len() > 10 {
            continue;
        }

        match representatives.entry(BasicBlockHashable { basic_block_data: bbd }) {
            Entry::Vacant(slot) => {
                // First block with this content: it becomes the representative.
                slot.insert(bb);
            }
            Entry::Occupied(existing) => {
                // An equal block was already recorded, so `bb` is a duplicate of it.
                let replacement = *existing.get();
                debug!("Inserting {:?} -> {:?}", bb, replacement);
                duplicates.try_insert(bb, replacement).expect("key was already inserted");
            }
        }
    }

    duplicates
}

/// Borrowed view of a basic block's data, used as a hash-map key when looking for
/// duplicate blocks. Hashing and equality are provided by the impls on this type.
struct BasicBlockHashable<'tcx, 'a> {
    // The block data that is hashed and compared.
    basic_block_data: &'a BasicBlockData<'tcx>,
}

impl Hash for BasicBlockHashable<'_, '_> {
    /// Feeds the block's statements and then its terminator kind into `state`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        let data = self.basic_block_data;
        hash_statements(state, data.statements.iter());
        // Only the terminator's kind is hashed, so span information is lost when
        // blocks are deduplicated.
        data.terminator().kind.hash(state);
    }
}

// Marker impl with no methods; together with `Hash` it lets the wrapper be used as
// a hash-map key.
impl Eq for BasicBlockHashable<'_, '_> {}

impl PartialEq for BasicBlockHashable<'_, '_> {
    /// Two blocks compare equal when they have the same number of statements, equal
    /// terminator kinds, and statement kinds that are pairwise equal according to
    /// `statement_eq`.
    fn eq(&self, other: &Self) -> bool {
        let (lhs, rhs) = (self.basic_block_data, other.basic_block_data);

        // Checked first so that the zip below never silently ignores extra statements.
        if lhs.statements.len() != rhs.statements.len() {
            return false;
        }
        if lhs.terminator().kind != rhs.terminator().kind {
            return false;
        }

        iter::zip(&lhs.statements, &rhs.statements).all(|(a, b)| statement_eq(&a.kind, &b.kind))
    }
}

/// Feeds the kind of every statement yielded by `iter` into `hasher`, in order.
fn hash_statements<'a, 'tcx, H: Hasher>(
    hasher: &mut H,
    iter: impl Iterator<Item = &'a Statement<'tcx>>,
) where
    'tcx: 'a,
{
    iter.for_each(|statement| statement_hash(hasher, &statement.kind));
}

/// Hashes a statement kind. An assignment hashes its place and then its rvalue via
/// `rvalue_hash`; every other kind uses its own `Hash` impl.
fn statement_hash<H: Hasher>(hasher: &mut H, stmt: &StatementKind<'_>) {
    if let StatementKind::Assign(box (place, rvalue)) = stmt {
        place.hash(hasher);
        rvalue_hash(hasher, rvalue);
    } else {
        stmt.hash(hasher);
    }
}

/// Hashes an rvalue. `Rvalue::Use` hashes its operand via `operand_hash`; every
/// other rvalue uses its own `Hash` impl.
fn rvalue_hash<H: Hasher>(hasher: &mut H, rvalue: &Rvalue<'_>) {
    if let Rvalue::Use(operand) = rvalue {
        operand_hash(hasher, operand);
    } else {
        rvalue.hash(hasher);
    }
}

/// Hashes an operand. For a constant only the `const_` field is hashed (`user_ty`
/// and `span` are ignored); every other operand uses its own `Hash` impl.
fn operand_hash<H: Hasher>(hasher: &mut H, operand: &Operand<'_>) {
    // All fields are listed in the pattern so that the ignored ones stay explicit.
    let Operand::Constant(box ConstOperand { user_ty: _, const_, span: _ }) = operand else {
        operand.hash(hasher);
        return;
    };
    const_.hash(hasher);
}

/// Compares two statement kinds. Two assignments are equal when their places are
/// equal and their rvalues are equal according to `rvalue_eq`; any other pair is
/// compared with `==`.
fn statement_eq<'tcx>(lhs: &StatementKind<'tcx>, rhs: &StatementKind<'tcx>) -> bool {
    let equal = if let (
        StatementKind::Assign(box (lhs_place, lhs_rvalue)),
        StatementKind::Assign(box (rhs_place, rhs_rvalue)),
    ) = (lhs, rhs)
    {
        lhs_place == rhs_place && rvalue_eq(lhs_rvalue, rhs_rvalue)
    } else {
        lhs == rhs
    };
    debug!("statement_eq lhs: `{:?}` rhs: `{:?}` result: {:?}", lhs, rhs, equal);
    equal
}

/// Compares two rvalues. Two `Rvalue::Use` values are equal when their operands
/// are equal according to `operand_eq`; any other pair is compared with `==`.
fn rvalue_eq<'tcx>(lhs: &Rvalue<'tcx>, rhs: &Rvalue<'tcx>) -> bool {
    let equal = if let (Rvalue::Use(lhs_operand), Rvalue::Use(rhs_operand)) = (lhs, rhs) {
        operand_eq(lhs_operand, rhs_operand)
    } else {
        lhs == rhs
    };
    debug!("rvalue_eq lhs: `{:?}` rhs: `{:?}` result: {:?}", lhs, rhs, equal);
    equal
}

/// Compares two operands. Two constants are equal when their `const_` fields are
/// equal (`user_ty` and `span` are ignored); any other pair is compared with `==`.
fn operand_eq<'tcx>(lhs: &Operand<'tcx>, rhs: &Operand<'tcx>) -> bool {
    let equal = if let (
        Operand::Constant(box ConstOperand { user_ty: _, const_: lhs_const, span: _ }),
        Operand::Constant(box ConstOperand { user_ty: _, const_: rhs_const, span: _ }),
    ) = (lhs, rhs)
    {
        lhs_const == rhs_const
    } else {
        lhs == rhs
    };
    debug!("operand_eq lhs: `{:?}` rhs: `{:?}` result: {:?}", lhs, rhs, equal);
    equal
}