Skip to main content

rustc_mir_transform/
add_subtyping_projections.rs

1use rustc_middle::mir::visit::MutVisitor;
2use rustc_middle::mir::*;
3use rustc_middle::ty::TyCtxt;
4
5use crate::patch::MirPatch;
6
7pub(super) struct Subtyper;
8
9struct SubTypeChecker<'a, 'tcx> {
10    tcx: TyCtxt<'tcx>,
11    patcher: MirPatch<'tcx>,
12    local_decls: &'a LocalDecls<'tcx>,
13}
14
15impl<'a, 'tcx> MutVisitor<'tcx> for SubTypeChecker<'a, 'tcx> {
16    fn tcx(&self) -> TyCtxt<'tcx> {
17        self.tcx
18    }
19
20    fn visit_assign(
21        &mut self,
22        place: &mut Place<'tcx>,
23        rvalue: &mut Rvalue<'tcx>,
24        location: Location,
25    ) {
26        if rvalue.is_generic_reborrow() {
27            return;
28        }
29        // We don't need to do anything for deref temps as they are
30        // not part of the source code, but used for desugaring purposes.
31        if self.local_decls[place.local].is_deref_temp() {
32            return;
33        }
34        let mut place_ty = place.ty(self.local_decls, self.tcx).ty;
35        let mut rval_ty = rvalue.ty(self.local_decls, self.tcx);
36        // Not erasing this causes `Free Regions` errors in validator,
37        // when rval is `ReStatic`.
38        rval_ty = self.tcx.erase_and_anonymize_regions(rval_ty);
39        place_ty = self.tcx.erase_and_anonymize_regions(place_ty);
40        if place_ty != rval_ty {
41            let temp = self
42                .patcher
43                .new_temp(rval_ty, self.local_decls[place.as_ref().local].source_info.span);
44            let new_place = Place::from(temp);
45            self.patcher.add_assign(location, new_place, rvalue.clone());
46            *rvalue = Rvalue::Cast(CastKind::Subtype, Operand::Move(new_place), place_ty);
47        }
48    }
49}
50
51// Aim here is to do this kind of transformation:
52//
53// let place: place_ty = rval;
54// // gets transformed to
55// let temp: rval_ty = rval;
56// let place: place_ty = temp as place_ty;
57impl<'tcx> crate::MirPass<'tcx> for Subtyper {
58    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
59        let patch = MirPatch::new(body);
60        let mut checker = SubTypeChecker { tcx, patcher: patch, local_decls: &body.local_decls };
61
62        for (bb, data) in body.basic_blocks.as_mut_preserves_cfg().iter_enumerated_mut() {
63            checker.visit_basic_block_data(bb, data);
64        }
65        checker.patcher.apply(body);
66    }
67
68    fn is_required(&self) -> bool {
69        true
70    }
71}