rustc_trait_selection/traits/
engine.rs
1use std::cell::RefCell;
2use std::fmt::Debug;
3
4use rustc_data_structures::fx::FxIndexSet;
5use rustc_errors::ErrorGuaranteed;
6use rustc_hir::def_id::{DefId, LocalDefId};
7use rustc_infer::infer::at::ToTrace;
8use rustc_infer::infer::canonical::{
9 Canonical, CanonicalQueryResponse, CanonicalVarValues, QueryResponse,
10};
11use rustc_infer::infer::{DefineOpaqueTypes, InferCtxt, InferOk, RegionResolutionError, TypeTrace};
12use rustc_infer::traits::PredicateObligations;
13use rustc_macros::extension;
14use rustc_middle::arena::ArenaAllocatable;
15use rustc_middle::traits::query::NoSolution;
16use rustc_middle::ty::error::TypeError;
17use rustc_middle::ty::{self, Ty, TyCtxt, TypeFoldable, Upcast, Variance};
18use rustc_type_ir::relate::Relate;
19
20use super::{FromSolverError, FulfillmentContext, ScrubbedTraitError, TraitEngine};
21use crate::error_reporting::InferCtxtErrorExt;
22use crate::regions::InferCtxtRegionExt;
23use crate::solve::{FulfillmentCtxt as NextFulfillmentCtxt, NextSolverError};
24use crate::traits::fulfill::OldSolverError;
25use crate::traits::{
26 FulfillmentError, NormalizeExt, Obligation, ObligationCause, PredicateObligation,
27 StructurallyNormalizeExt,
28};
29
30#[extension(pub trait TraitEngineExt<'tcx, E>)]
31impl<'tcx, E> dyn TraitEngine<'tcx, E>
32where
33 E: FromSolverError<'tcx, NextSolverError<'tcx>> + FromSolverError<'tcx, OldSolverError<'tcx>>,
34{
35 fn new(infcx: &InferCtxt<'tcx>) -> Box<Self> {
36 if infcx.next_trait_solver() {
37 Box::new(NextFulfillmentCtxt::new(infcx))
38 } else {
39 assert!(
40 !infcx.tcx.next_trait_solver_globally(),
41 "using old solver even though new solver is enabled globally"
42 );
43 Box::new(FulfillmentContext::new(infcx))
44 }
45 }
46}
47
48pub struct ObligationCtxt<'a, 'tcx, E = ScrubbedTraitError<'tcx>> {
51 pub infcx: &'a InferCtxt<'tcx>,
52 engine: RefCell<Box<dyn TraitEngine<'tcx, E>>>,
53}
54
55impl<'a, 'tcx> ObligationCtxt<'a, 'tcx, FulfillmentError<'tcx>> {
56 pub fn new_with_diagnostics(infcx: &'a InferCtxt<'tcx>) -> Self {
57 Self { infcx, engine: RefCell::new(<dyn TraitEngine<'tcx, _>>::new(infcx)) }
58 }
59}
60
61impl<'a, 'tcx> ObligationCtxt<'a, 'tcx, ScrubbedTraitError<'tcx>> {
62 pub fn new(infcx: &'a InferCtxt<'tcx>) -> Self {
63 Self { infcx, engine: RefCell::new(<dyn TraitEngine<'tcx, _>>::new(infcx)) }
64 }
65}
66
67impl<'a, 'tcx, E> ObligationCtxt<'a, 'tcx, E>
68where
69 E: 'tcx,
70{
71 pub fn register_obligation(&self, obligation: PredicateObligation<'tcx>) {
72 self.engine.borrow_mut().register_predicate_obligation(self.infcx, obligation);
73 }
74
75 pub fn register_obligations(
76 &self,
77 obligations: impl IntoIterator<Item = PredicateObligation<'tcx>>,
78 ) {
79 for obligation in obligations {
82 self.engine.borrow_mut().register_predicate_obligation(self.infcx, obligation)
83 }
84 }
85
86 pub fn register_infer_ok_obligations<T>(&self, infer_ok: InferOk<'tcx, T>) -> T {
87 let InferOk { value, obligations } = infer_ok;
88 self.engine.borrow_mut().register_predicate_obligations(self.infcx, obligations);
89 value
90 }
91
92 pub fn register_bound(
96 &self,
97 cause: ObligationCause<'tcx>,
98 param_env: ty::ParamEnv<'tcx>,
99 ty: Ty<'tcx>,
100 def_id: DefId,
101 ) {
102 let tcx = self.infcx.tcx;
103 let trait_ref = ty::TraitRef::new(tcx, def_id, [ty]);
104 self.register_obligation(Obligation {
105 cause,
106 recursion_depth: 0,
107 param_env,
108 predicate: trait_ref.upcast(tcx),
109 });
110 }
111
112 pub fn normalize<T: TypeFoldable<TyCtxt<'tcx>>>(
113 &self,
114 cause: &ObligationCause<'tcx>,
115 param_env: ty::ParamEnv<'tcx>,
116 value: T,
117 ) -> T {
118 let infer_ok = self.infcx.at(cause, param_env).normalize(value);
119 self.register_infer_ok_obligations(infer_ok)
120 }
121
122 pub fn eq<T: ToTrace<'tcx>>(
123 &self,
124 cause: &ObligationCause<'tcx>,
125 param_env: ty::ParamEnv<'tcx>,
126 expected: T,
127 actual: T,
128 ) -> Result<(), TypeError<'tcx>> {
129 self.infcx
130 .at(cause, param_env)
131 .eq(DefineOpaqueTypes::Yes, expected, actual)
132 .map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
133 }
134
135 pub fn eq_trace<T: Relate<TyCtxt<'tcx>>>(
136 &self,
137 cause: &ObligationCause<'tcx>,
138 param_env: ty::ParamEnv<'tcx>,
139 trace: TypeTrace<'tcx>,
140 expected: T,
141 actual: T,
142 ) -> Result<(), TypeError<'tcx>> {
143 self.infcx
144 .at(cause, param_env)
145 .eq_trace(DefineOpaqueTypes::Yes, trace, expected, actual)
146 .map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
147 }
148
149 pub fn sub<T: ToTrace<'tcx>>(
151 &self,
152 cause: &ObligationCause<'tcx>,
153 param_env: ty::ParamEnv<'tcx>,
154 expected: T,
155 actual: T,
156 ) -> Result<(), TypeError<'tcx>> {
157 self.infcx
158 .at(cause, param_env)
159 .sub(DefineOpaqueTypes::Yes, expected, actual)
160 .map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
161 }
162
163 pub fn relate<T: ToTrace<'tcx>>(
164 &self,
165 cause: &ObligationCause<'tcx>,
166 param_env: ty::ParamEnv<'tcx>,
167 variance: Variance,
168 expected: T,
169 actual: T,
170 ) -> Result<(), TypeError<'tcx>> {
171 self.infcx
172 .at(cause, param_env)
173 .relate(DefineOpaqueTypes::Yes, expected, variance, actual)
174 .map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
175 }
176
177 pub fn sup<T: ToTrace<'tcx>>(
179 &self,
180 cause: &ObligationCause<'tcx>,
181 param_env: ty::ParamEnv<'tcx>,
182 expected: T,
183 actual: T,
184 ) -> Result<(), TypeError<'tcx>> {
185 self.infcx
186 .at(cause, param_env)
187 .sup(DefineOpaqueTypes::Yes, expected, actual)
188 .map(|infer_ok| self.register_infer_ok_obligations(infer_ok))
189 }
190
191 #[must_use]
192 pub fn select_where_possible(&self) -> Vec<E> {
193 self.engine.borrow_mut().select_where_possible(self.infcx)
194 }
195
196 #[must_use]
197 pub fn select_all_or_error(&self) -> Vec<E> {
198 self.engine.borrow_mut().select_all_or_error(self.infcx)
199 }
200
201 #[must_use]
209 pub fn into_pending_obligations(self) -> PredicateObligations<'tcx> {
210 self.engine.borrow().pending_obligations()
211 }
212
213 pub fn resolve_regions_and_report_errors(
218 self,
219 body_id: LocalDefId,
220 param_env: ty::ParamEnv<'tcx>,
221 assumed_wf_tys: impl IntoIterator<Item = Ty<'tcx>>,
222 ) -> Result<(), ErrorGuaranteed> {
223 let errors = self.infcx.resolve_regions(body_id, param_env, assumed_wf_tys);
224 if errors.is_empty() {
225 Ok(())
226 } else {
227 Err(self.infcx.err_ctxt().report_region_errors(body_id, &errors))
228 }
229 }
230
231 #[must_use]
236 pub fn resolve_regions(
237 self,
238 body_id: LocalDefId,
239 param_env: ty::ParamEnv<'tcx>,
240 assumed_wf_tys: impl IntoIterator<Item = Ty<'tcx>>,
241 ) -> Vec<RegionResolutionError<'tcx>> {
242 self.infcx.resolve_regions(body_id, param_env, assumed_wf_tys)
243 }
244}
245
246impl<'tcx> ObligationCtxt<'_, 'tcx, FulfillmentError<'tcx>> {
247 pub fn assumed_wf_types_and_report_errors(
248 &self,
249 param_env: ty::ParamEnv<'tcx>,
250 def_id: LocalDefId,
251 ) -> Result<FxIndexSet<Ty<'tcx>>, ErrorGuaranteed> {
252 self.assumed_wf_types(param_env, def_id)
253 .map_err(|errors| self.infcx.err_ctxt().report_fulfillment_errors(errors))
254 }
255}
256
257impl<'tcx> ObligationCtxt<'_, 'tcx, ScrubbedTraitError<'tcx>> {
258 pub fn make_canonicalized_query_response<T>(
259 &self,
260 inference_vars: CanonicalVarValues<'tcx>,
261 answer: T,
262 ) -> Result<CanonicalQueryResponse<'tcx, T>, NoSolution>
263 where
264 T: Debug + TypeFoldable<TyCtxt<'tcx>>,
265 Canonical<'tcx, QueryResponse<'tcx, T>>: ArenaAllocatable<'tcx>,
266 {
267 self.infcx.make_canonicalized_query_response(
268 inference_vars,
269 answer,
270 &mut **self.engine.borrow_mut(),
271 )
272 }
273}
274
275impl<'tcx, E> ObligationCtxt<'_, 'tcx, E>
276where
277 E: FromSolverError<'tcx, NextSolverError<'tcx>>,
278{
279 pub fn assumed_wf_types(
280 &self,
281 param_env: ty::ParamEnv<'tcx>,
282 def_id: LocalDefId,
283 ) -> Result<FxIndexSet<Ty<'tcx>>, Vec<E>> {
284 let tcx = self.infcx.tcx;
285 let mut implied_bounds = FxIndexSet::default();
286 let mut errors = Vec::new();
287 for &(ty, span) in tcx.assumed_wf_types(def_id) {
288 let cause = ObligationCause::misc(span, def_id);
301 match self
302 .infcx
303 .at(&cause, param_env)
304 .deeply_normalize(ty, &mut **self.engine.borrow_mut())
305 {
306 Ok(normalized) => drop(implied_bounds.insert(normalized)),
308 Err(normalization_errors) => errors.extend(normalization_errors),
309 };
310 }
311
312 if errors.is_empty() { Ok(implied_bounds) } else { Err(errors) }
313 }
314
315 pub fn deeply_normalize<T: TypeFoldable<TyCtxt<'tcx>>>(
316 &self,
317 cause: &ObligationCause<'tcx>,
318 param_env: ty::ParamEnv<'tcx>,
319 value: T,
320 ) -> Result<T, Vec<E>> {
321 self.infcx.at(cause, param_env).deeply_normalize(value, &mut **self.engine.borrow_mut())
322 }
323
324 pub fn structurally_normalize_ty(
325 &self,
326 cause: &ObligationCause<'tcx>,
327 param_env: ty::ParamEnv<'tcx>,
328 value: Ty<'tcx>,
329 ) -> Result<Ty<'tcx>, Vec<E>> {
330 self.infcx
331 .at(cause, param_env)
332 .structurally_normalize_ty(value, &mut **self.engine.borrow_mut())
333 }
334
335 pub fn structurally_normalize_const(
336 &self,
337 cause: &ObligationCause<'tcx>,
338 param_env: ty::ParamEnv<'tcx>,
339 value: ty::Const<'tcx>,
340 ) -> Result<ty::Const<'tcx>, Vec<E>> {
341 self.infcx
342 .at(cause, param_env)
343 .structurally_normalize_const(value, &mut **self.engine.borrow_mut())
344 }
345
346 pub fn structurally_normalize_term(
347 &self,
348 cause: &ObligationCause<'tcx>,
349 param_env: ty::ParamEnv<'tcx>,
350 value: ty::Term<'tcx>,
351 ) -> Result<ty::Term<'tcx>, Vec<E>> {
352 self.infcx
353 .at(cause, param_env)
354 .structurally_normalize_term(value, &mut **self.engine.borrow_mut())
355 }
356}