1use std::collections::BTreeMap;
2
3use rustc_hir::def_id::DefId;
4use rustc_middle::ty::relate::{
5 self, Relate, RelateResult, TypeRelation, relate_args_with_variances,
6};
7use rustc_middle::ty::{self, RegionVid, Ty, TyCtxt, TypeVisitable};
8
9use super::{ConstraintDirection, PoloniusContext};
10use crate::universal_regions::UniversalRegions;
11
12impl PoloniusContext {
13 pub(crate) fn record_live_region_variance<'tcx>(
15 &mut self,
16 tcx: TyCtxt<'tcx>,
17 universal_regions: &UniversalRegions<'tcx>,
18 value: impl TypeVisitable<TyCtxt<'tcx>> + Relate<TyCtxt<'tcx>>,
19 ) {
20 let mut extractor = VarianceExtractor {
21 tcx,
22 ambient_variance: ty::Variance::Covariant,
23 directions: &mut self.live_region_variances,
24 universal_regions,
25 };
26 extractor.relate(value, value).expect("Can't have a type error relating to itself");
27 }
28}
29
30struct VarianceExtractor<'a, 'tcx> {
34 tcx: TyCtxt<'tcx>,
35 ambient_variance: ty::Variance,
36 directions: &'a mut BTreeMap<RegionVid, ConstraintDirection>,
37 universal_regions: &'a UniversalRegions<'tcx>,
38}
39
40impl<'tcx> VarianceExtractor<'_, 'tcx> {
41 fn record_variance(&mut self, region: ty::Region<'tcx>, variance: ty::Variance) {
42 if region.is_bound() {
52 return;
54 }
55
56 if region.is_erased() {
57 return;
64 }
65
66 let direction = match variance {
67 ty::Covariant => ConstraintDirection::Forward,
68 ty::Contravariant => ConstraintDirection::Backward,
69 ty::Invariant => ConstraintDirection::Bidirectional,
70 ty::Bivariant => {
71 return;
73 }
74 };
75
76 let region = self.universal_regions.to_region_vid(region);
77 self.directions
78 .entry(region)
79 .and_modify(|entry| {
80 if entry != &direction {
84 *entry = ConstraintDirection::Bidirectional;
85 }
86 })
87 .or_insert(direction);
88 }
89}
90
91impl<'tcx> TypeRelation<TyCtxt<'tcx>> for VarianceExtractor<'_, 'tcx> {
92 fn cx(&self) -> TyCtxt<'tcx> {
93 self.tcx
94 }
95
96 fn relate_ty_args(
97 &mut self,
98 a_ty: Ty<'tcx>,
99 _: Ty<'tcx>,
100 def_id: DefId,
101 a_args: ty::GenericArgsRef<'tcx>,
102 b_args: ty::GenericArgsRef<'tcx>,
103 _: impl FnOnce(ty::GenericArgsRef<'tcx>) -> Ty<'tcx>,
104 ) -> RelateResult<'tcx, Ty<'tcx>> {
105 let variances = self.cx().variances_of(def_id);
106 relate_args_with_variances(self, variances, a_args, b_args)?;
107 Ok(a_ty)
108 }
109
110 fn relate_with_variance<T: Relate<TyCtxt<'tcx>>>(
111 &mut self,
112 variance: ty::Variance,
113 _info: ty::VarianceDiagInfo<TyCtxt<'tcx>>,
114 a: T,
115 b: T,
116 ) -> RelateResult<'tcx, T> {
117 let old_ambient_variance = self.ambient_variance;
118 self.ambient_variance = self.ambient_variance.xform(variance);
119 let r = self.relate(a, b)?;
120 self.ambient_variance = old_ambient_variance;
121 Ok(r)
122 }
123
124 fn tys(&mut self, a: Ty<'tcx>, b: Ty<'tcx>) -> RelateResult<'tcx, Ty<'tcx>> {
125 match (&a, &b) {
(left_val, right_val) => {
if !(*left_val == *right_val) {
let kind = ::core::panicking::AssertKind::Eq;
::core::panicking::assert_failed(kind, &*left_val, &*right_val,
::core::option::Option::None);
}
}
};assert_eq!(a, b); relate::structurally_relate_tys(self, a, b)
127 }
128
129 fn regions(
130 &mut self,
131 a: ty::Region<'tcx>,
132 b: ty::Region<'tcx>,
133 ) -> RelateResult<'tcx, ty::Region<'tcx>> {
134 match (&a, &b) {
(left_val, right_val) => {
if !(*left_val == *right_val) {
let kind = ::core::panicking::AssertKind::Eq;
::core::panicking::assert_failed(kind, &*left_val, &*right_val,
::core::option::Option::None);
}
}
};assert_eq!(a, b); self.record_variance(a, self.ambient_variance);
136 Ok(a)
137 }
138
139 fn consts(
140 &mut self,
141 a: ty::Const<'tcx>,
142 b: ty::Const<'tcx>,
143 ) -> RelateResult<'tcx, ty::Const<'tcx>> {
144 match (&a, &b) {
(left_val, right_val) => {
if !(*left_val == *right_val) {
let kind = ::core::panicking::AssertKind::Eq;
::core::panicking::assert_failed(kind, &*left_val, &*right_val,
::core::option::Option::None);
}
}
};assert_eq!(a, b); relate::structurally_relate_consts(self, a, b)
146 }
147
148 fn binders<T>(
149 &mut self,
150 a: ty::Binder<'tcx, T>,
151 _: ty::Binder<'tcx, T>,
152 ) -> RelateResult<'tcx, ty::Binder<'tcx, T>>
153 where
154 T: Relate<TyCtxt<'tcx>>,
155 {
156 self.relate(a.skip_binder(), a.skip_binder())?;
157 Ok(a)
158 }
159}