rustc_hir_analysis/variance/
solve.rs1use rustc_hir::def_id::DefIdMap;
9use rustc_middle::ty;
10use tracing::debug;
11
12use super::constraints::*;
13use super::terms::VarianceTerm::*;
14use super::terms::*;
15
16fn glb(v1: ty::Variance, v2: ty::Variance) -> ty::Variance {
17 match (v1, v2) {
23 (ty::Invariant, _) | (_, ty::Invariant) => ty::Invariant,
24
25 (ty::Covariant, ty::Contravariant) => ty::Invariant,
26 (ty::Contravariant, ty::Covariant) => ty::Invariant,
27
28 (ty::Covariant, ty::Covariant) => ty::Covariant,
29
30 (ty::Contravariant, ty::Contravariant) => ty::Contravariant,
31
32 (x, ty::Bivariant) | (ty::Bivariant, x) => x,
33 }
34}
35struct SolveContext<'a, 'tcx> {
36 terms_cx: TermsContext<'a, 'tcx>,
37 constraints: Vec<Constraint<'a>>,
38
39 solutions: Vec<ty::Variance>,
41}
42
43pub(crate) fn solve_constraints<'tcx>(
44 constraints_cx: ConstraintContext<'_, 'tcx>,
45) -> ty::CrateVariancesMap<'tcx> {
46 let ConstraintContext { terms_cx, constraints, .. } = constraints_cx;
47
48 let mut solutions = vec![ty::Bivariant; terms_cx.inferred_terms.len()];
49 for (id, variances) in &terms_cx.lang_items {
50 let InferredIndex(start) = terms_cx.inferred_starts[id];
51 for (i, &variance) in variances.iter().enumerate() {
52 solutions[start + i] = variance;
53 }
54 }
55
56 let mut solutions_cx = SolveContext { terms_cx, constraints, solutions };
57 solutions_cx.solve();
58 let variances = solutions_cx.create_map();
59
60 ty::CrateVariancesMap { variances }
61}
62
63impl<'a, 'tcx> SolveContext<'a, 'tcx> {
64 fn solve(&mut self) {
65 let mut changed = true;
71 while changed {
72 changed = false;
73
74 for constraint in &self.constraints {
75 let Constraint { inferred, variance: term } = *constraint;
76 let InferredIndex(inferred) = inferred;
77 let variance = self.evaluate(term);
78 let old_value = self.solutions[inferred];
79 let new_value = glb(variance, old_value);
80 if old_value != new_value {
81 debug!(
82 "updating inferred {} \
83 from {:?} to {:?} due to {:?}",
84 inferred, old_value, new_value, term
85 );
86
87 self.solutions[inferred] = new_value;
88 changed = true;
89 }
90 }
91 }
92 }
93
94 fn enforce_const_invariance(&self, generics: &ty::Generics, variances: &mut [ty::Variance]) {
95 let tcx = self.terms_cx.tcx;
96
97 for param in generics.own_params.iter() {
99 if let ty::GenericParamDefKind::Const { .. } = param.kind {
100 variances[param.index as usize] = ty::Invariant;
101 }
102 }
103
104 if let Some(def_id) = generics.parent {
106 self.enforce_const_invariance(tcx.generics_of(def_id), variances);
107 }
108 }
109
110 fn create_map(&self) -> DefIdMap<&'tcx [ty::Variance]> {
111 let tcx = self.terms_cx.tcx;
112
113 let solutions = &self.solutions;
114 DefIdMap::from(self.terms_cx.inferred_starts.items().map(
115 |(&def_id, &InferredIndex(start))| {
116 let generics = tcx.generics_of(def_id);
117 let count = generics.count();
118
119 let variances = tcx.arena.alloc_slice(&solutions[start..(start + count)]);
120
121 self.enforce_const_invariance(generics, variances);
123
124 if let ty::FnDef(..) = tcx.type_of(def_id).instantiate_identity().kind() {
126 for variance in variances.iter_mut() {
127 if *variance == ty::Bivariant {
128 *variance = ty::Invariant;
129 }
130 }
131 }
132
133 (def_id.to_def_id(), &*variances)
134 },
135 ))
136 }
137
138 fn evaluate(&self, term: VarianceTermPtr<'a>) -> ty::Variance {
139 match *term {
140 ConstantTerm(v) => v,
141
142 TransformTerm(t1, t2) => {
143 let v1 = self.evaluate(t1);
144 let v2 = self.evaluate(t2);
145 v1.xform(v2)
146 }
147
148 InferredTerm(InferredIndex(index)) => self.solutions[index],
149 }
150 }
151}