1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315
//! This pass constructs a second coroutine body that is suitable for being returned from
//! the `FnOnce`/`AsyncFnOnce` implementations of coroutine-closures (e.g. async closures).
//!
//! Consider an async closure like:
//! ```rust
//! #![feature(async_closure)]
//!
//! let x = vec![1, 2, 3];
//!
//! let closure = async move || {
//! println!("{x:#?}");
//! };
//! ```
//!
//! This desugars to something like:
//! ```rust,ignore (invalid-borrowck)
//! let x = vec![1, 2, 3];
//!
//! let closure = move || {
//! async {
//! println!("{x:#?}");
//! }
//! };
//! ```
//!
//! Important to note here is that while the outer closure *moves* `x: Vec<i32>`
//! into its upvars, the inner `async` coroutine simply captures a reference to `x`.
//! This is the "magic" of async closures -- the futures that they return are
//! allowed to borrow from their parent closure's upvars.
//!
//! However, what happens when we call `closure` with `AsyncFnOnce` (or `FnOnce`,
//! since all async closures implement that too)? Well, recall the signature:
//! ```
//! use std::future::Future;
//! pub trait AsyncFnOnce<Args>
//! {
//! type CallOnceFuture: Future<Output = Self::Output>;
//! type Output;
//! fn async_call_once(
//! self,
//! args: Args
//! ) -> Self::CallOnceFuture;
//! }
//! ```
//!
//! This signature *consumes* the async closure (`self`) and returns a `CallOnceFuture`.
//! How do we deal with the fact that the coroutine is supposed to take a reference
//! to the captured `x` from the parent closure, when that parent closure has been
//! destroyed?
//!
//! This is the second piece of magic of async closures. We can simply create a
//! *second* `async` coroutine body where that `x` that was previously captured
//! by reference is now captured by value. This means that we consume the outer
//! closure and return a new coroutine that will hold onto all of these captures,
//! and drop them when it is finished (i.e. after it has been `.await`ed).
//!
//! We do this with the analysis below, which detects the captures that come from
//! borrowing from the outer closure, and we simply peel off a `deref` projection
//! from them. This second body is stored alongside the first body, and optimized
//! with it in lockstep. When we need to resolve a body for `FnOnce` or `AsyncFnOnce`,
//! we use this "by-move" body instead.
//!
//! ## How does this work?
//!
//! This pass essentially remaps the body of the (child) closure of the coroutine-closure
//! to take the set of upvars of the parent closure by value. This at least requires
//! changing a by-ref upvar to be by-value in the case that the outer coroutine-closure
//! captures something by value; however, it may also require renumbering field indices
//! in case precise captures (edition 2021 closure capture rules) caused the inner coroutine
//! to split one field capture into two.
use rustc_data_structures::steal::Steal;
use rustc_data_structures::unord::UnordMap;
use rustc_hir as hir;
use rustc_hir::def::DefKind;
use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_middle::bug;
use rustc_middle::hir::place::{Projection, ProjectionKind};
use rustc_middle::mir::visit::MutVisitor;
use rustc_middle::mir::{self, dump_mir};
use rustc_middle::ty::{self, InstanceKind, Ty, TyCtxt, TypeVisitableExt};
use rustc_span::symbol::kw;
use rustc_target::abi::{FieldIdx, VariantIdx};
pub fn coroutine_by_move_body_def_id<'tcx>(
tcx: TyCtxt<'tcx>,
coroutine_def_id: LocalDefId,
) -> DefId {
let body = tcx.mir_built(coroutine_def_id).borrow();
let Some(hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure)) =
tcx.coroutine_kind(coroutine_def_id)
else {
bug!("should only be invoked on coroutine-closures");
};
// Also, let's skip processing any bodies with errors, since there's no guarantee
// the MIR body will be constructed well.
let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!("{body:#?}") };
let args = args.as_coroutine();
let coroutine_kind = args.kind_ty().to_opt_closure_kind().unwrap();
let parent_def_id = tcx.local_parent(coroutine_def_id);
let ty::CoroutineClosure(_, parent_args) =
*tcx.type_of(parent_def_id).instantiate_identity().kind()
else {
bug!();
};
if parent_args.references_error() {
return coroutine_def_id.to_def_id();
}
let parent_closure_args = parent_args.as_coroutine_closure();
let num_args = parent_closure_args
.coroutine_closure_sig()
.skip_binder()
.tupled_inputs_ty
.tuple_fields()
.len();
let field_remapping: UnordMap<_, _> = ty::analyze_coroutine_closure_captures(
tcx.closure_captures(parent_def_id).iter().copied(),
tcx.closure_captures(coroutine_def_id).iter().skip(num_args).copied(),
|(parent_field_idx, parent_capture), (child_field_idx, child_capture)| {
// Store this set of additional projections (fields and derefs).
// We need to re-apply them later.
let mut child_precise_captures =
child_capture.place.projections[parent_capture.place.projections.len()..].to_vec();
// If the parent capture is by-ref, then we need to apply an additional
// deref before applying any further projections to this place.
if parent_capture.is_by_ref() {
child_precise_captures.insert(
0,
Projection { ty: parent_capture.place.ty(), kind: ProjectionKind::Deref },
);
}
// If the child capture is by-ref, then we need to apply a "ref"
// projection (i.e. `&`) at the end. But wait! We don't have that
// as a projection kind. So instead, we can apply its dual and
// *peel* a deref off of the place when it shows up in the MIR body.
// Luckily, by construction this is always possible.
let peel_deref = if child_capture.is_by_ref() {
assert!(
parent_capture.is_by_ref() || coroutine_kind != ty::ClosureKind::FnOnce,
"`FnOnce` coroutine-closures return coroutines that capture from \
their body; it will always result in a borrowck error!"
);
true
} else {
false
};
// Regarding the behavior above, you may think that it's redundant to both
// insert a deref and then peel a deref if the parent and child are both
// captured by-ref. This would be correct, except for the case where we have
// precise capturing projections, since the inserted deref is to the *beginning*
// and the peeled deref is at the *end*. I cannot seem to actually find a
// case where this happens, though, but let's keep this code flexible.
// Finally, store the type of the parent's captured place. We need
// this when building the field projection in the MIR body later on.
let mut parent_capture_ty = parent_capture.place.ty();
parent_capture_ty = match parent_capture.info.capture_kind {
ty::UpvarCapture::ByValue => parent_capture_ty,
ty::UpvarCapture::ByRef(kind) => Ty::new_ref(
tcx,
tcx.lifetimes.re_erased,
parent_capture_ty,
kind.to_mutbl_lossy(),
),
};
(
FieldIdx::from_usize(child_field_idx + num_args),
(
FieldIdx::from_usize(parent_field_idx + num_args),
parent_capture_ty,
peel_deref,
child_precise_captures,
),
)
},
)
.collect();
if coroutine_kind == ty::ClosureKind::FnOnce {
assert_eq!(field_remapping.len(), tcx.closure_captures(parent_def_id).len());
// The by-move body is just the body :)
return coroutine_def_id.to_def_id();
}
let by_move_coroutine_ty = tcx
.instantiate_bound_regions_with_erased(parent_closure_args.coroutine_closure_sig())
.to_coroutine_given_kind_and_upvars(
tcx,
parent_closure_args.parent_args(),
coroutine_def_id.to_def_id(),
ty::ClosureKind::FnOnce,
tcx.lifetimes.re_erased,
parent_closure_args.tupled_upvars_ty(),
parent_closure_args.coroutine_captures_by_ref_ty(),
);
let mut by_move_body = body.clone();
MakeByMoveBody { tcx, field_remapping, by_move_coroutine_ty }.visit_body(&mut by_move_body);
dump_mir(tcx, false, "coroutine_by_move", &0, &by_move_body, |_, _| Ok(()));
let body_def = tcx.create_def(coroutine_def_id, kw::Empty, DefKind::SyntheticCoroutineBody);
by_move_body.source =
mir::MirSource::from_instance(InstanceKind::Item(body_def.def_id().to_def_id()));
// Inherited from the by-ref coroutine.
body_def.codegen_fn_attrs(tcx.codegen_fn_attrs(coroutine_def_id).clone());
body_def.constness(tcx.constness(coroutine_def_id).clone());
body_def.coroutine_kind(tcx.coroutine_kind(coroutine_def_id).clone());
body_def.def_ident_span(tcx.def_ident_span(coroutine_def_id));
body_def.def_span(tcx.def_span(coroutine_def_id));
body_def.explicit_predicates_of(tcx.explicit_predicates_of(coroutine_def_id).clone());
body_def.generics_of(tcx.generics_of(coroutine_def_id).clone());
body_def.param_env(tcx.param_env(coroutine_def_id).clone());
body_def.predicates_of(tcx.predicates_of(coroutine_def_id).clone());
// The type of the coroutine is the `by_move_coroutine_ty`.
body_def.type_of(ty::EarlyBinder::bind(by_move_coroutine_ty));
body_def.mir_built(tcx.arena.alloc(Steal::new(by_move_body)));
body_def.def_id().to_def_id()
}
/// MIR visitor state used to rewrite a coroutine body so that its upvar places
/// are expressed in terms of remapped fields of the capture struct local.
struct MakeByMoveBody<'tcx> {
// Type context, used to intern the rewritten place projections.
tcx: TyCtxt<'tcx>,
// Keyed by the field index found in a place rooted at `CAPTURE_STRUCT_LOCAL`.
// The value is, in order: the remapped field index, the type used for the
// remapped field projection, whether one leading deref must be peeled from the
// remaining projections, and the "bridging" projections inserted right after
// the remapped field.
field_remapping: UnordMap<FieldIdx, (FieldIdx, Ty<'tcx>, bool, Vec<Projection<'tcx>>)>,
// Type assigned to the declaration of `CAPTURE_STRUCT_LOCAL` (the self arg).
by_move_coroutine_ty: Ty<'tcx>,
}
impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
fn tcx(&self) -> TyCtxt<'tcx> {
self.tcx
}
fn visit_place(
&mut self,
place: &mut mir::Place<'tcx>,
context: mir::visit::PlaceContext,
location: mir::Location,
) {
// Initializing an upvar local always starts with `CAPTURE_STRUCT_LOCAL` and a
// field projection. If this is in `field_remapping`, then it must not be an
// arg from calling the closure, but instead an upvar.
if place.local == ty::CAPTURE_STRUCT_LOCAL
&& let Some((&mir::ProjectionElem::Field(idx, _), projection)) =
place.projection.split_first()
&& let Some(&(remapped_idx, remapped_ty, peel_deref, ref bridging_projections)) =
self.field_remapping.get(&idx)
{
// As noted before, if the parent closure captures a field by value, and
// the child captures a field by ref, then for the by-move body we're
// generating, we also are taking that field by value. Peel off a deref,
// since a layer of ref'ing has now become redundant.
let final_projections = if peel_deref {
let Some((mir::ProjectionElem::Deref, projection)) = projection.split_first()
else {
bug!(
"There should be at least a single deref for an upvar local initialization, found {projection:#?}"
);
};
// There may be more derefs, since we may also implicitly reborrow
// a captured mut pointer.
projection
} else {
projection
};
// These projections are applied in order to "bridge" the local that we are
// currently transforming *from* the old upvar that the by-ref coroutine used
// to capture *to* the upvar of the parent coroutine-closure. For example, if
// the parent captures `&s` but the child captures `&(s.field)`, then we will
// apply a field projection.
let bridging_projections = bridging_projections.iter().map(|elem| match elem.kind {
ProjectionKind::Deref => mir::ProjectionElem::Deref,
ProjectionKind::Field(idx, VariantIdx::ZERO) => {
mir::ProjectionElem::Field(idx, elem.ty)
}
_ => unreachable!("precise captures only through fields and derefs"),
});
// We start out with an adjusted field index (and ty), representing the
// upvar that we get from our parent closure. We apply any of the additional
// projections to make sure that to the rest of the body of the closure, the
// place looks the same, and then apply that final deref if necessary.
*place = mir::Place {
local: place.local,
projection: self.tcx.mk_place_elems_from_iter(
[mir::ProjectionElem::Field(remapped_idx, remapped_ty)]
.into_iter()
.chain(bridging_projections)
.chain(final_projections.iter().copied()),
),
};
}
self.super_place(place, context, location);
}
fn visit_local_decl(&mut self, local: mir::Local, local_decl: &mut mir::LocalDecl<'tcx>) {
// Replace the type of the self arg.
if local == ty::CAPTURE_STRUCT_LOCAL {
local_decl.ty = self.by_move_coroutine_ty;
}
}
}