1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
//! This pass constructs a second coroutine body suitable for returning from the
//! `FnOnce`/`AsyncFnOnce` implementations of coroutine-closures (e.g. async closures).
//!
//! Consider an async closure like:
//! ```rust
//! #![feature(async_closure)]
//!
//! let x = vec![1, 2, 3];
//!
//! let closure = async move || {
//!     println!("{x:#?}");
//! };
//! ```
//!
//! This desugars to something like:
//! ```rust,ignore (invalid-borrowck)
//! let x = vec![1, 2, 3];
//!
//! let closure = move || {
//!     async {
//!         println!("{x:#?}");
//!     }
//! };
//! ```
//!
//! Note that while the outer closure *moves* `x: Vec<i32>` into its upvars, the
//! inner `async` coroutine simply captures a reference to `x`.
//! This is the "magic" of async closures -- the futures that they return are
//! allowed to borrow from their parent closure's upvars.
//!
//! However, what happens when we call `closure` with `AsyncFnOnce` (or `FnOnce`,
//! since all async closures implement that too)? Well, recall the signature:
//! ```
//! use std::future::Future;
//! pub trait AsyncFnOnce<Args>
//! {
//!     type CallOnceFuture: Future<Output = Self::Output>;
//!     type Output;
//!     fn async_call_once(
//!         self,
//!         args: Args
//!     ) -> Self::CallOnceFuture;
//! }
//! ```
//!
//! This signature *consumes* the async closure (`self`) and returns a `CallOnceFuture`.
//! How do we deal with the fact that the coroutine is supposed to take a reference
//! to the captured `x` from the parent closure, when that parent closure has been
//! destroyed?
//!
//! This is the second piece of magic of async closures. We can simply create a
//! *second* `async` coroutine body where that `x` that was previously captured
//! by reference is now captured by value. This means that we consume the outer
//! closure and return a new coroutine that will hold onto all of these captures,
//! and drop them when it is finished (i.e. after it has been `.await`ed).
//!
//! We do this with the analysis below, which detects the captures that come from
//! borrowing from the outer closure, and we simply peel off a `deref` projection
//! from them. This second body is stored alongside the first body, and optimized
//! with it in lockstep. When we need to resolve a body for `FnOnce` or `AsyncFnOnce`,
//! we use this "by-move" body instead.
//!
//! ## How does this work?
//!
//! This pass essentially remaps the body of the (child) coroutine of the coroutine-closure
//! to take the set of upvars of the parent closure by value. This at least requires
//! changing a by-ref upvar to be by-value in the case that the outer coroutine-closure
//! captures something by value; however, it may also require renumbering field indices
//! in case precise captures (edition 2021 closure capture rules) caused the inner coroutine
//! to split one field capture into two.

use rustc_data_structures::steal::Steal;
use rustc_data_structures::unord::UnordMap;
use rustc_hir as hir;
use rustc_hir::def::DefKind;
use rustc_hir::def_id::{DefId, LocalDefId};
use rustc_middle::bug;
use rustc_middle::hir::place::{Projection, ProjectionKind};
use rustc_middle::mir::visit::MutVisitor;
use rustc_middle::mir::{self, dump_mir};
use rustc_middle::ty::{self, InstanceKind, Ty, TyCtxt, TypeVisitableExt};
use rustc_span::symbol::kw;
use rustc_target::abi::{FieldIdx, VariantIdx};

/// Returns the `DefId` of the "by-move" body of `coroutine_def_id`, the coroutine
/// returned by a coroutine-closure (e.g. an async closure).
///
/// The by-move body is a copy of the coroutine's built MIR, rewritten so that its
/// upvars are taken from the parent coroutine-closure by value. It is the body used
/// when the coroutine-closure is called through `FnOnce`/`AsyncFnOnce` (see the
/// module docs).
///
/// No new body is created, and `coroutine_def_id` itself is returned, in these cases:
/// - the coroutine was already built for `ClosureKind::FnOnce`;
/// - the coroutine's type references errors;
/// - the parent's generic args reference errors.
///
/// Otherwise, a `DefKind::SyntheticCoroutineBody` definition is created under
/// `coroutine_def_id`. Its queries are fed, mostly by copying them from the by-ref
/// coroutine, and its `DefId` is returned.
///
/// ICEs (via `bug!`) if `coroutine_def_id` is not a closure-desugared coroutine whose
/// parent is a coroutine-closure.
pub fn coroutine_by_move_body_def_id<'tcx>(
    tcx: TyCtxt<'tcx>,
    coroutine_def_id: LocalDefId,
) -> DefId {
    let body = tcx.mir_built(coroutine_def_id).borrow();

    let Some(hir::CoroutineKind::Desugared(_, hir::CoroutineSource::Closure)) =
        tcx.coroutine_kind(coroutine_def_id)
    else {
        bug!("should only be invoked on coroutine-closures");
    };

    // Skip processing any bodies with errors, since there's no guarantee the MIR body
    // will be constructed well. In that case the by-ref body stands in for the by-move one.
    let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
    if coroutine_ty.references_error() {
        return coroutine_def_id.to_def_id();
    }

    let ty::Coroutine(_, args) = *coroutine_ty.kind() else { bug!("{body:#?}") };
    let args = args.as_coroutine();

    // The closure kind (`Fn`/`FnMut`/`FnOnce`) this coroutine was built for.
    let coroutine_kind = args.kind_ty().to_opt_closure_kind().unwrap();

    let parent_def_id = tcx.local_parent(coroutine_def_id);
    let ty::CoroutineClosure(_, parent_args) =
        *tcx.type_of(parent_def_id).instantiate_identity().kind()
    else {
        bug!("the parent of a coroutine-closure's coroutine must be a coroutine-closure");
    };
    // Likewise, skip the body if the parent coroutine-closure's args reference errors.
    if parent_args.references_error() {
        return coroutine_def_id.to_def_id();
    }

    let parent_closure_args = parent_args.as_coroutine_closure();
    // The child's first `num_args` captures correspond to the closure's arguments.
    // They are skipped when pairing captures below, and the upvar field indices are
    // offset by `num_args`.
    let num_args = parent_closure_args
        .coroutine_closure_sig()
        .skip_binder()
        .tupled_inputs_ty
        .tuple_fields()
        .len();

    let field_remapping: UnordMap<_, _> = ty::analyze_coroutine_closure_captures(
        tcx.closure_captures(parent_def_id).iter().copied(),
        tcx.closure_captures(coroutine_def_id).iter().skip(num_args).copied(),
        |(parent_field_idx, parent_capture), (child_field_idx, child_capture)| {
            // Store this set of additional projections (fields and derefs).
            // We need to re-apply them later.
            let mut child_precise_captures =
                child_capture.place.projections[parent_capture.place.projections.len()..].to_vec();

            // If the parent capture is by-ref, then we need to apply an additional
            // deref before applying any further projections to this place.
            if parent_capture.is_by_ref() {
                child_precise_captures.insert(
                    0,
                    Projection { ty: parent_capture.place.ty(), kind: ProjectionKind::Deref },
                );
            }

            // If the child capture is by-ref, then we need to apply a "ref"
            // projection (i.e. `&`) at the end. But wait! We don't have that
            // as a projection kind. So instead, we can apply its dual and
            // *peel* a deref off of the place when it shows up in the MIR body.
            // Luckily, by construction this is always possible.
            let peel_deref = child_capture.is_by_ref();
            if peel_deref {
                assert!(
                    parent_capture.is_by_ref() || coroutine_kind != ty::ClosureKind::FnOnce,
                    "`FnOnce` coroutine-closures return coroutines that capture from \
                        their body; it will always result in a borrowck error!"
                );
            }

            // Regarding the behavior above, you may think that it's redundant to both
            // insert a deref and then peel a deref if the parent and child are both
            // captured by-ref. This would be correct, except for the case where we have
            // precise capturing projections, since the inserted deref is to the *beginning*
            // and the peeled deref is at the *end*. I cannot seem to actually find a
            // case where this happens, though, but let's keep this code flexible.

            // Finally, store the type of the parent's captured place. We need
            // this when building the field projection in the MIR body later on.
            // A by-ref capture is stored as a reference to the place's type.
            let parent_place_ty = parent_capture.place.ty();
            let parent_capture_ty = match parent_capture.info.capture_kind {
                ty::UpvarCapture::ByValue => parent_place_ty,
                ty::UpvarCapture::ByRef(kind) => Ty::new_ref(
                    tcx,
                    tcx.lifetimes.re_erased,
                    parent_place_ty,
                    kind.to_mutbl_lossy(),
                ),
            };

            (
                FieldIdx::from_usize(child_field_idx + num_args),
                (
                    FieldIdx::from_usize(parent_field_idx + num_args),
                    parent_capture_ty,
                    peel_deref,
                    child_precise_captures,
                ),
            )
        },
    )
    .collect();

    if coroutine_kind == ty::ClosureKind::FnOnce {
        // Sanity check: the remapping has exactly one entry per parent capture.
        assert_eq!(field_remapping.len(), tcx.closure_captures(parent_def_id).len());
        // The by-move body is just the body :)
        return coroutine_def_id.to_def_id();
    }

    let by_move_coroutine_ty = tcx
        .instantiate_bound_regions_with_erased(parent_closure_args.coroutine_closure_sig())
        .to_coroutine_given_kind_and_upvars(
            tcx,
            parent_closure_args.parent_args(),
            coroutine_def_id.to_def_id(),
            ty::ClosureKind::FnOnce,
            tcx.lifetimes.re_erased,
            parent_closure_args.tupled_upvars_ty(),
            parent_closure_args.coroutine_captures_by_ref_ty(),
        );

    // Rewrite a clone of the by-ref body; the original body is left untouched.
    let mut by_move_body = body.clone();
    MakeByMoveBody { tcx, field_remapping, by_move_coroutine_ty }.visit_body(&mut by_move_body);
    dump_mir(tcx, false, "coroutine_by_move", &0, &by_move_body, |_, _| Ok(()));

    let body_def = tcx.create_def(coroutine_def_id, kw::Empty, DefKind::SyntheticCoroutineBody);
    by_move_body.source =
        mir::MirSource::from_instance(InstanceKind::Item(body_def.def_id().to_def_id()));

    // Inherited from the by-ref coroutine.
    body_def.codegen_fn_attrs(tcx.codegen_fn_attrs(coroutine_def_id).clone());
    body_def.constness(tcx.constness(coroutine_def_id).clone());
    body_def.coroutine_kind(tcx.coroutine_kind(coroutine_def_id).clone());
    body_def.def_ident_span(tcx.def_ident_span(coroutine_def_id));
    body_def.def_span(tcx.def_span(coroutine_def_id));
    body_def.explicit_predicates_of(tcx.explicit_predicates_of(coroutine_def_id).clone());
    body_def.generics_of(tcx.generics_of(coroutine_def_id).clone());
    body_def.param_env(tcx.param_env(coroutine_def_id).clone());
    body_def.predicates_of(tcx.predicates_of(coroutine_def_id).clone());

    // The type of the coroutine is the `by_move_coroutine_ty`.
    body_def.type_of(ty::EarlyBinder::bind(by_move_coroutine_ty));

    body_def.mir_built(tcx.arena.alloc(Steal::new(by_move_body)));

    body_def.def_id().to_def_id()
}

/// MIR visitor that turns a clone of the by-ref coroutine body into the by-move body.
///
/// Every place that starts at `CAPTURE_STRUCT_LOCAL` with a field projection listed in
/// `field_remapping` is rewritten to go through the corresponding upvar of the parent
/// coroutine-closure. The type of `CAPTURE_STRUCT_LOCAL` itself is replaced with
/// `by_move_coroutine_ty`.
struct MakeByMoveBody<'tcx> {
    tcx: TyCtxt<'tcx>,
    /// New type for `CAPTURE_STRUCT_LOCAL` in the rewritten body.
    by_move_coroutine_ty: Ty<'tcx>,
    /// Keyed by the child (by-ref coroutine) upvar field index. Each value holds:
    /// - the parent's upvar field index;
    /// - the type of that field;
    /// - whether the deref right after the field projection must be peeled;
    /// - the extra projections that bridge from the parent's place to the child's place.
    field_remapping: UnordMap<FieldIdx, (FieldIdx, Ty<'tcx>, bool, Vec<Projection<'tcx>>)>,
}

impl<'tcx> MutVisitor<'tcx> for MakeByMoveBody<'tcx> {
    fn tcx(&self) -> TyCtxt<'tcx> {
        self.tcx
    }

    fn visit_place(
        &mut self,
        place: &mut mir::Place<'tcx>,
        context: mir::visit::PlaceContext,
        location: mir::Location,
    ) {
        // Upvar accesses begin with `CAPTURE_STRUCT_LOCAL` followed by a field projection.
        // Only fields found in `field_remapping` are upvars. Any other field of the capture
        // struct (an argument from calling the closure) is left alone.
        if place.local == ty::CAPTURE_STRUCT_LOCAL
            && let Some((&mir::ProjectionElem::Field(child_field, _), projection)) =
                place.projection.split_first()
            && let Some(&(parent_field, parent_field_ty, peel_deref, ref bridge)) =
                self.field_remapping.get(&child_field)
        {
            // When the child captured this upvar by ref, the by-move body now owns it, so
            // the deref immediately following the field projection is redundant and is
            // dropped. Only that first deref goes; later derefs (e.g. from implicitly
            // reborrowing a captured mut pointer) are kept.
            let tail = match (peel_deref, projection.split_first()) {
                (false, _) => projection,
                (true, Some((mir::ProjectionElem::Deref, after_deref))) => after_deref,
                (true, _) => bug!(
                    "There should be at least a single deref for an upvar local initialization, found {projection:#?}"
                ),
            };

            // Projections that lead from the upvar the parent coroutine-closure captured
            // to the place the by-ref coroutine captured. For example, if the parent
            // captures `&s` but the child captures `&(s.field)`, this yields a field
            // projection.
            let bridging = bridge.iter().map(|step| match step.kind {
                ProjectionKind::Field(field, VariantIdx::ZERO) => {
                    mir::ProjectionElem::Field(field, step.ty)
                }
                ProjectionKind::Deref => mir::ProjectionElem::Deref,
                _ => unreachable!("precise captures only through fields and derefs"),
            });

            // The new projection list has three parts, in order:
            // 1. the parent's upvar field;
            // 2. the bridging projections;
            // 3. whatever the original place applied after the upvar access.
            // The rest of the body therefore sees a place that looks the same.
            let new_projection = self.tcx.mk_place_elems_from_iter(
                std::iter::once(mir::ProjectionElem::Field(parent_field, parent_field_ty))
                    .chain(bridging)
                    .chain(tail.iter().copied()),
            );
            *place = mir::Place { local: place.local, projection: new_projection };
        }
        self.super_place(place, context, location);
    }

    fn visit_local_decl(&mut self, local: mir::Local, local_decl: &mut mir::LocalDecl<'tcx>) {
        // Only the self arg (the capture-struct local) changes type. Every other
        // declaration is left as-is, and `super_local_decl` is not called.
        if local != ty::CAPTURE_STRUCT_LOCAL {
            return;
        }
        local_decl.ty = self.by_move_coroutine_ty;
    }
}