rustc_hir_analysis/
hir_wf_check.rs

1use rustc_hir::intravisit::{self, Visitor, VisitorExt};
2use rustc_hir::{self as hir, AmbigArg, ForeignItem, ForeignItemKind};
3use rustc_infer::infer::TyCtxtInferExt;
4use rustc_infer::traits::{ObligationCause, WellFormedLoc};
5use rustc_middle::bug;
6use rustc_middle::query::Providers;
7use rustc_middle::ty::{self, TyCtxt, TypingMode, fold_regions};
8use rustc_span::def_id::LocalDefId;
9use rustc_trait_selection::traits::{self, ObligationCtxt};
10use tracing::debug;
11
12use crate::collect::ItemCtxt;
13
14pub(crate) fn provide(providers: &mut Providers) {
15    *providers = Providers { diagnostic_hir_wf_check, ..*providers };
16}
17
18// Ideally, this would be in `rustc_trait_selection`, but we
19// need access to `ItemCtxt`
20fn diagnostic_hir_wf_check<'tcx>(
21    tcx: TyCtxt<'tcx>,
22    (predicate, loc): (ty::Predicate<'tcx>, WellFormedLoc),
23) -> Option<ObligationCause<'tcx>> {
24    let def_id = match loc {
25        WellFormedLoc::Ty(def_id) => def_id,
26        WellFormedLoc::Param { function, param_idx: _ } => function,
27    };
28    let hir_id = tcx.local_def_id_to_hir_id(def_id);
29
30    // HIR wfcheck should only ever happen as part of improving an existing error
31    tcx.dcx()
32        .span_delayed_bug(tcx.def_span(def_id), "Performed HIR wfcheck without an existing error!");
33
34    let icx = ItemCtxt::new(tcx, def_id);
35
36    // To perform HIR-based WF checking, we iterate over all HIR types
37    // that occur 'inside' the item we're checking. For example,
38    // given the type `Option<MyStruct<u8>>`, we will check
39    // `Option<MyStruct<u8>>`, `MyStruct<u8>`, and `u8`.
40    // For each type, we perform a well-formed check, and see if we get
41    // an error that matches our expected predicate. We save
42    // the `ObligationCause` corresponding to the *innermost* type,
43    // which is the most specific type that we can point to.
44    // In general, the different components of an `hir::Ty` may have
45    // completely different spans due to macro invocations. Pointing
46    // to the most accurate part of the type can be the difference
47    // between a useless span (e.g. the macro invocation site)
48    // and a useful span (e.g. a user-provided type passed into the macro).
49    //
50    // This approach is quite inefficient - we redo a lot of work done
51    // by the normal WF checker. However, this code is run at most once
52    // per reported error - it will have no impact when compilation succeeds,
53    // and should only have an impact if a very large number of errors is
54    // displayed to the user.
55    struct HirWfCheck<'tcx> {
56        tcx: TyCtxt<'tcx>,
57        predicate: ty::Predicate<'tcx>,
58        cause: Option<ObligationCause<'tcx>>,
59        cause_depth: usize,
60        icx: ItemCtxt<'tcx>,
61        def_id: LocalDefId,
62        param_env: ty::ParamEnv<'tcx>,
63        depth: usize,
64    }
65
66    impl<'tcx> Visitor<'tcx> for HirWfCheck<'tcx> {
67        fn visit_ty(&mut self, ty: &'tcx hir::Ty<'tcx, AmbigArg>) {
68            let infcx = self.tcx.infer_ctxt().build(TypingMode::non_body_analysis());
69            let ocx = ObligationCtxt::new_with_diagnostics(&infcx);
70
71            // We don't handle infer vars but we wouldn't handle them anyway as we're creating a
72            // fresh `InferCtxt` in this function.
73            let tcx_ty = self.icx.lower_ty(ty.as_unambig_ty());
74            // This visitor can walk into binders, resulting in the `tcx_ty` to
75            // potentially reference escaping bound variables. We simply erase
76            // those here.
77            let tcx_ty = fold_regions(self.tcx, tcx_ty, |r, _| {
78                if r.is_bound() { self.tcx.lifetimes.re_erased } else { r }
79            });
80            let cause = traits::ObligationCause::new(
81                ty.span,
82                self.def_id,
83                traits::ObligationCauseCode::WellFormed(None),
84            );
85
86            ocx.register_obligation(traits::Obligation::new(
87                self.tcx,
88                cause,
89                self.param_env,
90                ty::PredicateKind::Clause(ty::ClauseKind::WellFormed(tcx_ty.into())),
91            ));
92
93            for error in ocx.select_all_or_error() {
94                debug!("Wf-check got error for {:?}: {:?}", ty, error);
95                if error.obligation.predicate == self.predicate {
96                    // Save the cause from the greatest depth - this corresponds
97                    // to picking more-specific types (e.g. `MyStruct<u8>`)
98                    // over less-specific types (e.g. `Option<MyStruct<u8>>`)
99                    if self.depth >= self.cause_depth {
100                        self.cause = Some(error.obligation.cause);
101                        self.cause_depth = self.depth
102                    }
103                }
104            }
105
106            self.depth += 1;
107            intravisit::walk_ty(self, ty);
108            self.depth -= 1;
109        }
110    }
111
112    let mut visitor = HirWfCheck {
113        tcx,
114        predicate,
115        cause: None,
116        cause_depth: 0,
117        icx,
118        def_id,
119        param_env: tcx.param_env(def_id.to_def_id()),
120        depth: 0,
121    };
122
123    // Get the starting `hir::Ty` using our `WellFormedLoc`.
124    // We will walk 'into' this type to try to find
125    // a more precise span for our predicate.
126    let tys = match loc {
127        WellFormedLoc::Ty(_) => match tcx.hir_node(hir_id) {
128            hir::Node::ImplItem(item) => match item.kind {
129                hir::ImplItemKind::Type(ty) => vec![ty],
130                hir::ImplItemKind::Const(ty, _) => vec![ty],
131                ref item => bug!("Unexpected ImplItem {:?}", item),
132            },
133            hir::Node::TraitItem(item) => match item.kind {
134                hir::TraitItemKind::Type(_, ty) => ty.into_iter().collect(),
135                hir::TraitItemKind::Const(ty, _) => vec![ty],
136                ref item => bug!("Unexpected TraitItem {:?}", item),
137            },
138            hir::Node::Item(item) => match item.kind {
139                hir::ItemKind::TyAlias(_, ty, _)
140                | hir::ItemKind::Static(_, ty, _, _)
141                | hir::ItemKind::Const(_, ty, _, _) => vec![ty],
142                hir::ItemKind::Impl(impl_) => match &impl_.of_trait {
143                    Some(t) => t
144                        .path
145                        .segments
146                        .last()
147                        .iter()
148                        .flat_map(|seg| seg.args().args)
149                        .filter_map(|arg| {
150                            if let hir::GenericArg::Type(ty) = arg {
151                                Some(ty.as_unambig_ty())
152                            } else {
153                                None
154                            }
155                        })
156                        .chain([impl_.self_ty])
157                        .collect(),
158                    None => {
159                        vec![impl_.self_ty]
160                    }
161                },
162                ref item => bug!("Unexpected item {:?}", item),
163            },
164            hir::Node::Field(field) => vec![field.ty],
165            hir::Node::ForeignItem(ForeignItem {
166                kind: ForeignItemKind::Static(ty, _, _), ..
167            }) => vec![*ty],
168            hir::Node::GenericParam(hir::GenericParam {
169                kind: hir::GenericParamKind::Type { default: Some(ty), .. },
170                ..
171            }) => vec![*ty],
172            hir::Node::AnonConst(_) => {
173                if let Some(const_param_id) = tcx.hir().opt_const_param_default_param_def_id(hir_id)
174                    && let hir::Node::GenericParam(hir::GenericParam {
175                        kind: hir::GenericParamKind::Const { ty, .. },
176                        ..
177                    }) = tcx.hir_node_by_def_id(const_param_id)
178                {
179                    vec![*ty]
180                } else {
181                    vec![]
182                }
183            }
184            ref node => bug!("Unexpected node {:?}", node),
185        },
186        WellFormedLoc::Param { function: _, param_idx } => {
187            let fn_decl = tcx.hir_fn_decl_by_hir_id(hir_id).unwrap();
188            // Get return type
189            if param_idx as usize == fn_decl.inputs.len() {
190                match fn_decl.output {
191                    hir::FnRetTy::Return(ty) => vec![ty],
192                    // The unit type `()` is always well-formed
193                    hir::FnRetTy::DefaultReturn(_span) => vec![],
194                }
195            } else {
196                vec![&fn_decl.inputs[param_idx as usize]]
197            }
198        }
199    };
200    for ty in tys {
201        visitor.visit_ty_unambig(ty);
202    }
203    visitor.cause
204}