1use std::fmt::Debug;
2
3use rustc_hir::def_id::DefId;
4use rustc_hir::lang_items::LangItem;
5pub use rustc_infer::infer::*;
6use rustc_macros::extension;
7use rustc_middle::arena::ArenaAllocatable;
8use rustc_middle::infer::canonical::{
9 Canonical, CanonicalQueryInput, CanonicalQueryResponse, QueryResponse,
10};
11use rustc_middle::traits::query::NoSolution;
12use rustc_middle::ty::{self, GenericArg, Ty, TyCtxt, TypeFoldable, Upcast};
13use rustc_span::DUMMY_SP;
14use tracing::instrument;
15
16use crate::infer::at::ToTrace;
17use crate::traits::query::evaluate_obligation::InferCtxtExt as _;
18use crate::traits::{self, Obligation, ObligationCause, ObligationCtxt};
19
20#[extension(pub trait InferCtxtExt<'tcx>)]
21impl<'tcx> InferCtxt<'tcx> {
22 fn can_eq<T: ToTrace<'tcx>>(&self, param_env: ty::ParamEnv<'tcx>, a: T, b: T) -> bool {
23 self.probe(|_| {
24 let ocx = ObligationCtxt::new(self);
25 let Ok(()) = ocx.eq(&ObligationCause::dummy(), param_env, a, b) else {
26 return false;
27 };
28 ocx.try_evaluate_obligations().is_empty()
29 })
30 }
31
32 fn type_is_copy_modulo_regions(&self, param_env: ty::ParamEnv<'tcx>, ty: Ty<'tcx>) -> bool {
33 let ty = self.resolve_vars_if_possible(ty);
34 let copy_def_id = self.tcx.require_lang_item(LangItem::Copy, DUMMY_SP);
35 traits::type_known_to_meet_bound_modulo_regions(self, param_env, ty, copy_def_id)
36 }
37
38 fn type_is_clone_modulo_regions(&self, param_env: ty::ParamEnv<'tcx>, ty: Ty<'tcx>) -> bool {
39 let ty = self.resolve_vars_if_possible(ty);
40 let clone_def_id = self.tcx.require_lang_item(LangItem::Clone, DUMMY_SP);
41 traits::type_known_to_meet_bound_modulo_regions(self, param_env, ty, clone_def_id)
42 }
43
44 fn type_is_use_cloned_modulo_regions(
45 &self,
46 param_env: ty::ParamEnv<'tcx>,
47 ty: Ty<'tcx>,
48 ) -> bool {
49 let ty = self.resolve_vars_if_possible(ty);
50 let use_cloned_def_id = self.tcx.require_lang_item(LangItem::UseCloned, DUMMY_SP);
51 traits::type_known_to_meet_bound_modulo_regions(self, param_env, ty, use_cloned_def_id)
52 }
53
54 fn type_is_sized_modulo_regions(&self, param_env: ty::ParamEnv<'tcx>, ty: Ty<'tcx>) -> bool {
55 let lang_item = self.tcx.require_lang_item(LangItem::Sized, DUMMY_SP);
56 traits::type_known_to_meet_bound_modulo_regions(self, param_env, ty, lang_item)
57 }
58
59 #[instrument(level = "debug", skip(self, params), ret)]
88 fn type_implements_trait(
89 &self,
90 trait_def_id: DefId,
91 params: impl IntoIterator<Item: Into<GenericArg<'tcx>>>,
92 param_env: ty::ParamEnv<'tcx>,
93 ) -> traits::EvaluationResult {
94 let trait_ref = ty::TraitRef::new(self.tcx, trait_def_id, params);
95
96 let obligation = traits::Obligation {
97 cause: traits::ObligationCause::dummy(),
98 param_env,
99 recursion_depth: 0,
100 predicate: trait_ref.upcast(self.tcx),
101 };
102 self.evaluate_obligation(&obligation).unwrap_or(traits::EvaluationResult::EvaluatedToErr)
103 }
104
105 fn type_implements_trait_shallow(
114 &self,
115 trait_def_id: DefId,
116 ty: Ty<'tcx>,
117 param_env: ty::ParamEnv<'tcx>,
118 ) -> Option<Vec<traits::FulfillmentError<'tcx>>> {
119 self.probe(|_snapshot| {
120 let ocx = ObligationCtxt::new_with_diagnostics(self);
121 ocx.register_obligation(Obligation::new(
122 self.tcx,
123 ObligationCause::dummy(),
124 param_env,
125 ty::TraitRef::new(self.tcx, trait_def_id, [ty]),
126 ));
127 let errors = ocx.try_evaluate_obligations();
128 for error in &errors {
132 let Some(trait_clause) = error.obligation.predicate.as_trait_clause() else {
133 continue;
134 };
135 let Some(bound_ty) = trait_clause.self_ty().no_bound_vars() else { continue };
136 if trait_clause.def_id() == trait_def_id
137 && ocx.eq(&ObligationCause::dummy(), param_env, bound_ty, ty).is_ok()
138 {
139 return None;
140 }
141 }
142 Some(errors)
143 })
144 }
145}
146
147#[extension(pub trait InferCtxtBuilderExt<'tcx>)]
148impl<'tcx> InferCtxtBuilder<'tcx> {
149 fn enter_canonical_trait_query<K, R>(
166 self,
167 canonical_key: &CanonicalQueryInput<'tcx, K>,
168 operation: impl FnOnce(&ObligationCtxt<'_, 'tcx>, K) -> Result<R, NoSolution>,
169 ) -> Result<CanonicalQueryResponse<'tcx, R>, NoSolution>
170 where
171 K: TypeFoldable<TyCtxt<'tcx>>,
172 R: Debug + TypeFoldable<TyCtxt<'tcx>>,
173 Canonical<'tcx, QueryResponse<'tcx, R>>: ArenaAllocatable<'tcx>,
174 {
175 let (infcx, key, var_values) = self.build_with_canonical(DUMMY_SP, canonical_key);
176 let ocx = ObligationCtxt::new(&infcx);
177 let value = operation(&ocx, key)?;
178 ocx.make_canonicalized_query_response(var_values, value)
179 }
180}