rustc_borrowck/type_check/
input_output.rs1use std::assert_matches::assert_matches;
11
12use itertools::Itertools;
13use rustc_hir as hir;
14use rustc_infer::infer::{BoundRegionConversionTime, RegionVariableOrigin};
15use rustc_middle::mir::*;
16use rustc_middle::ty::{self, Ty};
17use rustc_span::Span;
18use tracing::{debug, instrument};
19
20use super::{Locations, TypeChecker};
21use crate::renumber::RegionCtxt;
22use crate::universal_regions::DefiningTy;
23
24impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
25 #[instrument(skip(self), level = "debug")]
28 pub(super) fn check_signature_annotation(&mut self) {
29 let mir_def_id = self.body.source.def_id().expect_local();
30
31 if !self.tcx().is_closure_like(mir_def_id.to_def_id()) {
32 return;
33 }
34
35 let user_provided_poly_sig = self.tcx().closure_user_provided_sig(mir_def_id);
36
37 let user_provided_sig = self.instantiate_canonical(self.body.span, &user_provided_poly_sig);
42 let mut user_provided_sig = self.infcx.instantiate_binder_with_fresh_vars(
43 self.body.span,
44 BoundRegionConversionTime::FnCall,
45 user_provided_sig,
46 );
47
48 if let DefiningTy::CoroutineClosure(_, args) = self.universal_regions.defining_ty {
52 assert_matches!(
53 self.tcx().coroutine_kind(self.tcx().coroutine_for_closure(mir_def_id)),
54 Some(hir::CoroutineKind::Desugared(
55 hir::CoroutineDesugaring::Async | hir::CoroutineDesugaring::Gen,
56 hir::CoroutineSource::Closure
57 )),
58 "this needs to be modified if we're lowering non-async closures"
59 );
60 let args = args.as_coroutine_closure();
63 let tupled_upvars_ty = ty::CoroutineClosureSignature::tupled_upvars_by_closure_kind(
64 self.tcx(),
65 args.kind(),
66 Ty::new_tup(self.tcx(), user_provided_sig.inputs()),
67 args.tupled_upvars_ty(),
68 args.coroutine_captures_by_ref_ty(),
69 self.infcx.next_region_var(RegionVariableOrigin::Misc(self.body.span), || {
70 RegionCtxt::Unknown
71 }),
72 );
73
74 let next_ty_var = || self.infcx.next_ty_var(self.body.span);
75 let output_ty = Ty::new_coroutine(
76 self.tcx(),
77 self.tcx().coroutine_for_closure(mir_def_id),
78 ty::CoroutineArgs::new(
79 self.tcx(),
80 ty::CoroutineArgsParts {
81 parent_args: args.parent_args(),
82 kind_ty: Ty::from_coroutine_closure_kind(self.tcx(), args.kind()),
83 return_ty: user_provided_sig.output(),
84 tupled_upvars_ty,
85 resume_ty: next_ty_var(),
88 yield_ty: next_ty_var(),
89 witness: next_ty_var(),
90 },
91 )
92 .args,
93 );
94
95 user_provided_sig = self.tcx().mk_fn_sig(
96 user_provided_sig.inputs().iter().copied(),
97 output_ty,
98 user_provided_sig.c_variadic,
99 user_provided_sig.safety,
100 user_provided_sig.abi,
101 );
102 }
103
104 let is_coroutine_with_implicit_resume_ty = self.tcx().is_coroutine(mir_def_id.to_def_id())
105 && user_provided_sig.inputs().is_empty();
106
107 for (&user_ty, arg_decl) in user_provided_sig.inputs().iter().zip_eq(
108 self.body
111 .args_iter()
112 .skip(1 + if is_coroutine_with_implicit_resume_ty { 1 } else { 0 })
113 .map(|local| &self.body.local_decls[local]),
114 ) {
115 self.ascribe_user_type_skip_wf(
116 arg_decl.ty,
117 ty::UserType::new(ty::UserTypeKind::Ty(user_ty)),
118 arg_decl.source_info.span,
119 );
120 }
121
122 let output_decl = &self.body.local_decls[RETURN_PLACE];
124 self.ascribe_user_type_skip_wf(
125 output_decl.ty,
126 ty::UserType::new(ty::UserTypeKind::Ty(user_provided_sig.output())),
127 output_decl.source_info.span,
128 );
129 }
130
131 #[instrument(skip(self), level = "debug")]
132 pub(super) fn equate_inputs_and_outputs(&mut self, normalized_inputs_and_output: &[Ty<'tcx>]) {
133 let (&normalized_output_ty, normalized_input_tys) =
134 normalized_inputs_and_output.split_last().unwrap();
135
136 debug!(?normalized_output_ty);
137 debug!(?normalized_input_tys);
138
139 for (argument_index, &normalized_input_ty) in normalized_input_tys.iter().enumerate() {
141 if argument_index + 1 >= self.body.local_decls.len() {
142 self.tcx()
143 .dcx()
144 .span_bug(self.body.span, "found more normalized_input_ty than local_decls");
145 }
146
147 let local = Local::from_usize(argument_index + 1);
149
150 let mir_input_ty = self.body.local_decls[local].ty;
151
152 let mir_input_span = self.body.local_decls[local].source_info.span;
153 self.equate_normalized_input_or_output(
154 normalized_input_ty,
155 mir_input_ty,
156 mir_input_span,
157 );
158 }
159
160 if let Some(mir_yield_ty) = self.body.yield_ty() {
161 let yield_span = self.body.local_decls[RETURN_PLACE].source_info.span;
162 self.equate_normalized_input_or_output(
163 self.universal_regions.yield_ty.unwrap(),
164 mir_yield_ty,
165 yield_span,
166 );
167 }
168
169 if let Some(mir_resume_ty) = self.body.resume_ty() {
170 let yield_span = self.body.local_decls[RETURN_PLACE].source_info.span;
171 self.equate_normalized_input_or_output(
172 self.universal_regions.resume_ty.unwrap(),
173 mir_resume_ty,
174 yield_span,
175 );
176 }
177
178 let mir_output_ty = self.body.local_decls[RETURN_PLACE].ty;
180 let output_span = self.body.local_decls[RETURN_PLACE].source_info.span;
181 self.equate_normalized_input_or_output(normalized_output_ty, mir_output_ty, output_span);
182 }
183
184 #[instrument(skip(self), level = "debug")]
185 fn equate_normalized_input_or_output(&mut self, a: Ty<'tcx>, b: Ty<'tcx>, span: Span) {
186 if let Err(_) =
187 self.eq_types(a, b, Locations::All(span), ConstraintCategory::BoringNoLocation)
188 {
189 let b = self.normalize(b, Locations::All(span));
194
195 if let Err(terr) =
199 self.eq_types(a, b, Locations::All(span), ConstraintCategory::BoringNoLocation)
200 {
201 span_mirbug!(
202 self,
203 Location::START,
204 "equate_normalized_input_or_output: `{:?}=={:?}` failed with `{:?}`",
205 a,
206 b,
207 terr
208 );
209 }
210 }
211 }
212}