Skip to main content

rustc_mir_transform/coroutine/
mod.rs

1//! This is the implementation of the pass which transforms coroutines into state machines.
2//!
//! MIR generation for coroutines creates a function which has a self argument that is
//! passed by value. This argument is effectively a coroutine type which only contains upvars and
//! is only used for this argument inside the MIR for the coroutine.
6//! It is passed by value to enable upvars to be moved out of it. Drop elaboration runs on that
7//! MIR before this pass and creates drop flags for MIR locals.
8//! It will also drop the coroutine argument (which only consists of upvars) if any of the upvars
9//! are moved out of. This pass elaborates the drops of upvars / coroutine argument in the case
10//! that none of the upvars were moved out of. This is because we cannot have any drops of this
11//! coroutine in the MIR, since it is used to create the drop glue for the coroutine. We'd get
12//! infinite recursion otherwise.
13//!
14//! This pass creates the implementation for either the `Coroutine::resume` or `Future::poll`
15//! function and the drop shim for the coroutine based on the MIR input.
16//! It converts the coroutine argument from Self to &mut Self adding derefs in the MIR as needed.
17//! It computes the final layout of the coroutine struct which looks like this:
18//!     First upvars are stored
19//!     It is followed by the coroutine state field.
20//!     Then finally the MIR locals which are live across a suspension point are stored.
21//!     ```ignore (illustrative)
22//!     struct Coroutine {
23//!         upvars...,
24//!         state: u32,
25//!         mir_locals...,
26//!     }
27//!     ```
28//! This pass computes the meaning of the state field and the MIR locals which are live
29//! across a suspension point. There are however three hardcoded coroutine states:
//!     0 - Coroutine has not been resumed yet
31//!     1 - Coroutine has returned / is completed
32//!     2 - Coroutine has been poisoned
33//!
34//! It also rewrites `return x` and `yield y` as setting a new coroutine state and returning
35//! `CoroutineState::Complete(x)` and `CoroutineState::Yielded(y)`,
36//! or `Poll::Ready(x)` and `Poll::Pending` respectively.
37//! MIR locals which are live across a suspension point are moved to the coroutine struct
38//! with references to them being updated with references to the coroutine struct.
39//!
40//! The pass creates two functions which have a switch on the coroutine state giving
41//! the action to take.
42//!
43//! One of them is the implementation of `Coroutine::resume` / `Future::poll`.
44//! For coroutines with state 0 (unresumed) it starts the execution of the coroutine.
45//! For coroutines with state 1 (returned) and state 2 (poisoned) it panics.
46//! Otherwise it continues the execution from the last suspension point.
47//!
48//! The other function is the drop glue for the coroutine.
49//! For coroutines with state 0 (unresumed) it drops the upvars of the coroutine.
50//! For coroutines with state 1 (returned) and state 2 (poisoned) it does nothing.
51//! Otherwise it drops all the values in scope at the last suspension point.
52
53mod by_move_body;
54mod drop;
55mod layout;
56
57pub(super) use by_move_body::coroutine_by_move_body_def_id;
58use drop::{
59    create_coroutine_drop_shim, create_coroutine_drop_shim_async,
60    create_coroutine_drop_shim_proxy_async, elaborate_coroutine_drops, has_async_drops,
61    insert_clean_drop,
62};
63pub(super) use layout::mir_coroutine_witnesses;
64use layout::{CoroutineSavedLocals, compute_layout, locals_live_across_suspend_points};
65use rustc_abi::{FieldIdx, VariantIdx};
66use rustc_hir::lang_items::LangItem;
67use rustc_hir::{self as hir, CoroutineDesugaring, CoroutineKind};
68use rustc_index::bit_set::{BitMatrix, DenseBitSet, GrowableBitSet};
69use rustc_index::{Idx, IndexVec, indexvec};
70use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
71use rustc_middle::mir::*;
72use rustc_middle::ty::{
73    self, CoroutineArgs, CoroutineArgsExt, GenericArgsRef, InstanceKind, Ty, TyCtxt,
74};
75use rustc_middle::{bug, span_bug};
76use rustc_mir_dataflow::impls::always_storage_live_locals;
77use rustc_span::def_id::DefId;
78use tracing::{debug, instrument};
79
80use crate::deref_separator::deref_finder;
81use crate::{abort_unwinding_calls, pass_manager as pm, simplify};
82
83pub(super) struct StateTransform;
84
85struct RenameLocalVisitor<'tcx> {
86    from: Local,
87    to: Local,
88    tcx: TyCtxt<'tcx>,
89}
90
91impl<'tcx> MutVisitor<'tcx> for RenameLocalVisitor<'tcx> {
92    fn tcx(&self) -> TyCtxt<'tcx> {
93        self.tcx
94    }
95
96    fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
97        if *local == self.from {
98            *local = self.to;
99        } else if *local == self.to {
100            *local = self.from;
101        }
102    }
103
104    fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) {
105        match terminator.kind {
106            TerminatorKind::Return => {
107                // Do not replace the implicit `_0` access here, as that's not possible. The
108                // transform already handles `return` correctly.
109            }
110            _ => self.super_terminator(terminator, location),
111        }
112    }
113}
114
115struct SelfArgVisitor<'tcx> {
116    tcx: TyCtxt<'tcx>,
117    new_base: Place<'tcx>,
118}
119
120impl<'tcx> SelfArgVisitor<'tcx> {
121    fn new(tcx: TyCtxt<'tcx>, new_base: Place<'tcx>) -> Self {
122        Self { tcx, new_base }
123    }
124}
125
126impl<'tcx> MutVisitor<'tcx> for SelfArgVisitor<'tcx> {
127    fn tcx(&self) -> TyCtxt<'tcx> {
128        self.tcx
129    }
130
131    fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
132        assert_ne!(*local, SELF_ARG);
133    }
134
135    fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, _: Location) {
136        if place.local == SELF_ARG {
137            replace_base(place, self.new_base, self.tcx);
138        }
139
140        for elem in place.projection.iter() {
141            if let PlaceElem::Index(local) = elem {
142                assert_ne!(local, SELF_ARG);
143            }
144        }
145    }
146}
147
148#[tracing::instrument(level = "trace", skip(tcx))]
149fn replace_base<'tcx>(place: &mut Place<'tcx>, new_base: Place<'tcx>, tcx: TyCtxt<'tcx>) {
150    place.local = new_base.local;
151
152    let mut new_projection = new_base.projection.to_vec();
153    new_projection.append(&mut place.projection.to_vec());
154
155    place.projection = tcx.mk_place_elems(&new_projection);
156    tracing::trace!(?place);
157}
158
159const SELF_ARG: Local = Local::arg(0);
160pub(crate) const CTX_ARG: Local = Local::arg(1);
161
162/// A `yield` point in the coroutine.
163struct SuspensionPoint<'tcx> {
164    /// State discriminant used when suspending or resuming at this point.
165    state: usize,
166    /// The block to jump to after resumption.
167    resume: BasicBlock,
168    /// Where to move the resume argument after resumption.
169    resume_arg: Place<'tcx>,
170    /// Which block to jump to if the coroutine is dropped in this state.
171    drop: Option<BasicBlock>,
172    /// Set of locals that have live storage while at this suspension point.
173    storage_liveness: GrowableBitSet<Local>,
174}
175
176struct TransformVisitor<'tcx> {
177    tcx: TyCtxt<'tcx>,
178    coroutine_kind: hir::CoroutineKind,
179
180    // The type of the discriminant in the coroutine struct
181    discr_ty: Ty<'tcx>,
182
183    // Mapping from Local to (type of local, coroutine struct index)
184    remap: IndexVec<Local, Option<(Ty<'tcx>, VariantIdx, FieldIdx)>>,
185
186    // A map from a suspension point in a block to the locals which have live storage at that point
187    storage_liveness: IndexVec<BasicBlock, Option<DenseBitSet<Local>>>,
188
189    // A list of suspension points, generated during the transform
190    suspension_points: Vec<SuspensionPoint<'tcx>>,
191
192    // The set of locals that have no `StorageLive`/`StorageDead` annotations.
193    always_live_locals: DenseBitSet<Local>,
194
195    // New local we just create to hold the `CoroutineState` value.
196    new_ret_local: Local,
197
198    old_yield_ty: Ty<'tcx>,
199
200    old_ret_ty: Ty<'tcx>,
201}
202
203impl<'tcx> TransformVisitor<'tcx> {
204    fn insert_none_ret_block(&self, body: &mut Body<'tcx>) -> BasicBlock {
205        let block = body.basic_blocks.next_index();
206        let source_info = SourceInfo::outermost(body.span);
207
208        let none_value = match self.coroutine_kind {
209            CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
210                span_bug!(body.span, "`Future`s are not fused inherently")
211            }
212            CoroutineKind::Coroutine(_) => span_bug!(body.span, "`Coroutine`s cannot be fused"),
213            // `gen` continues return `None`
214            CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
215                let option_def_id = self.tcx.require_lang_item(LangItem::Option, body.span);
216                make_aggregate_adt(
217                    option_def_id,
218                    VariantIdx::ZERO,
219                    self.tcx.mk_args(&[self.old_yield_ty.into()]),
220                    IndexVec::new(),
221                )
222            }
223            // `async gen` continues to return `Poll::Ready(None)`
224            CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
225                let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() };
226                let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
227                let yield_ty = args.type_at(0);
228                Rvalue::Use(
229                    Operand::Constant(Box::new(ConstOperand {
230                        span: source_info.span,
231                        const_: Const::Unevaluated(
232                            UnevaluatedConst::new(
233                                self.tcx.require_lang_item(LangItem::AsyncGenFinished, body.span),
234                                self.tcx.mk_args(&[yield_ty.into()]),
235                            ),
236                            self.old_yield_ty,
237                        ),
238                        user_ty: None,
239                    })),
240                    WithRetag::Yes,
241                )
242            }
243        };
244
245        let statements = vec![Statement::new(
246            source_info,
247            StatementKind::Assign(Box::new((Place::return_place(), none_value))),
248        )];
249
250        body.basic_blocks_mut().push(BasicBlockData::new_stmts(
251            statements,
252            Some(Terminator { source_info, kind: TerminatorKind::Return }),
253            false,
254        ));
255
256        block
257    }
258
259    // Make a `CoroutineState` or `Poll` variant assignment.
260    //
261    // `core::ops::CoroutineState` only has single element tuple variants,
262    // so we can just write to the downcasted first field and then set the
263    // discriminant to the appropriate variant.
264    #[tracing::instrument(level = "trace", skip(self, statements))]
265    fn make_state(
266        &self,
267        val: Operand<'tcx>,
268        source_info: SourceInfo,
269        is_return: bool,
270        statements: &mut Vec<Statement<'tcx>>,
271    ) {
272        const ZERO: VariantIdx = VariantIdx::ZERO;
273        const ONE: VariantIdx = VariantIdx::from_usize(1);
274        let rvalue = match self.coroutine_kind {
275            CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
276                let poll_def_id = self.tcx.require_lang_item(LangItem::Poll, source_info.span);
277                let args = self.tcx.mk_args(&[self.old_ret_ty.into()]);
278                let (variant_idx, operands) = if is_return {
279                    (ZERO, indexvec![val]) // Poll::Ready(val)
280                } else {
281                    (ONE, IndexVec::new()) // Poll::Pending
282                };
283                make_aggregate_adt(poll_def_id, variant_idx, args, operands)
284            }
285            CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
286                let option_def_id = self.tcx.require_lang_item(LangItem::Option, source_info.span);
287                let args = self.tcx.mk_args(&[self.old_yield_ty.into()]);
288                let (variant_idx, operands) = if is_return {
289                    (ZERO, IndexVec::new()) // None
290                } else {
291                    (ONE, indexvec![val]) // Some(val)
292                };
293                make_aggregate_adt(option_def_id, variant_idx, args, operands)
294            }
295            CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
296                if is_return {
297                    let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() };
298                    let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
299                    let yield_ty = args.type_at(0);
300                    Rvalue::Use(
301                        Operand::Constant(Box::new(ConstOperand {
302                            span: source_info.span,
303                            const_: Const::Unevaluated(
304                                UnevaluatedConst::new(
305                                    self.tcx.require_lang_item(
306                                        LangItem::AsyncGenFinished,
307                                        source_info.span,
308                                    ),
309                                    self.tcx.mk_args(&[yield_ty.into()]),
310                                ),
311                                self.old_yield_ty,
312                            ),
313                            user_ty: None,
314                        })),
315                        WithRetag::Yes,
316                    )
317                } else {
318                    Rvalue::Use(val, WithRetag::Yes)
319                }
320            }
321            CoroutineKind::Coroutine(_) => {
322                let coroutine_state_def_id =
323                    self.tcx.require_lang_item(LangItem::CoroutineState, source_info.span);
324                let args = self.tcx.mk_args(&[self.old_yield_ty.into(), self.old_ret_ty.into()]);
325                let variant_idx = if is_return {
326                    ONE // CoroutineState::Complete(val)
327                } else {
328                    ZERO // CoroutineState::Yielded(val)
329                };
330                make_aggregate_adt(coroutine_state_def_id, variant_idx, args, indexvec![val])
331            }
332        };
333
334        // Assign to `new_ret_local`, which will be replaced by `RETURN_PLACE` later.
335        statements.push(Statement::new(
336            source_info,
337            StatementKind::Assign(Box::new((self.new_ret_local.into(), rvalue))),
338        ));
339    }
340
341    // Create a Place referencing a coroutine struct field
342    #[tracing::instrument(level = "trace", skip(self), ret)]
343    fn make_field(&self, variant_index: VariantIdx, idx: FieldIdx, ty: Ty<'tcx>) -> Place<'tcx> {
344        let self_place = Place::from(SELF_ARG);
345        let base = self.tcx.mk_place_downcast_unnamed(self_place, variant_index);
346        let mut projection = base.projection.to_vec();
347        projection.push(ProjectionElem::Field(idx, ty));
348
349        Place { local: base.local, projection: self.tcx.mk_place_elems(&projection) }
350    }
351
352    // Create a statement which changes the discriminant
353    #[tracing::instrument(level = "trace", skip(self))]
354    fn set_discr(&self, state_disc: VariantIdx, source_info: SourceInfo) -> Statement<'tcx> {
355        let self_place = Place::from(SELF_ARG);
356        Statement::new(
357            source_info,
358            StatementKind::SetDiscriminant {
359                place: Box::new(self_place),
360                variant_index: state_disc,
361            },
362        )
363    }
364
365    // Create a statement which reads the discriminant into a temporary
366    #[tracing::instrument(level = "trace", skip(self, body))]
367    fn get_discr(&self, body: &mut Body<'tcx>) -> (Statement<'tcx>, Place<'tcx>) {
368        let temp_decl = LocalDecl::new(self.discr_ty, body.span);
369        let local_decls_len = body.local_decls.push(temp_decl);
370        let temp = Place::from(local_decls_len);
371
372        let self_place = Place::from(SELF_ARG);
373        let assign = Statement::new(
374            SourceInfo::outermost(body.span),
375            StatementKind::Assign(Box::new((temp, Rvalue::Discriminant(self_place)))),
376        );
377        (assign, temp)
378    }
379
380    /// Swaps all references of `old_local` and `new_local`.
381    #[tracing::instrument(level = "trace", skip(self, body))]
382    fn replace_local(&mut self, old_local: Local, new_local: Local, body: &mut Body<'tcx>) {
383        body.local_decls.swap(old_local, new_local);
384
385        let mut visitor = RenameLocalVisitor { from: old_local, to: new_local, tcx: self.tcx };
386        visitor.visit_body(body);
387        for suspension in &mut self.suspension_points {
388            let ctxt = PlaceContext::MutatingUse(MutatingUseContext::Yield);
389            let location = Location { block: START_BLOCK, statement_index: 0 };
390            visitor.visit_place(&mut suspension.resume_arg, ctxt, location);
391        }
392    }
393}
394
395impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
396    fn tcx(&self) -> TyCtxt<'tcx> {
397        self.tcx
398    }
399
400    #[tracing::instrument(level = "trace", skip(self), ret)]
401    fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _location: Location) {
402        assert!(!self.remap.contains(*local));
403    }
404
405    #[tracing::instrument(level = "trace", skip(self), ret)]
406    fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, _location: Location) {
407        // Replace an Local in the remap with a coroutine struct access
408        if let Some(&Some((ty, variant_index, idx))) = self.remap.get(place.local) {
409            replace_base(place, self.make_field(variant_index, idx, ty), self.tcx);
410        }
411    }
412
413    #[tracing::instrument(level = "trace", skip(self, stmt), ret)]
414    fn visit_statement(&mut self, stmt: &mut Statement<'tcx>, location: Location) {
415        // Remove StorageLive and StorageDead statements for remapped locals
416        if let StatementKind::StorageLive(l) | StatementKind::StorageDead(l) = stmt.kind
417            && self.remap.contains(l)
418        {
419            stmt.make_nop(true);
420        }
421        self.super_statement(stmt, location);
422    }
423
424    #[tracing::instrument(level = "trace", skip(self, term), ret)]
425    fn visit_terminator(&mut self, term: &mut Terminator<'tcx>, location: Location) {
426        if let TerminatorKind::Return = term.kind {
427            // `visit_basic_block_data` introduces `Return` terminators which read `RETURN_PLACE`.
428            // But this `RETURN_PLACE` is already remapped, so we should not touch it again.
429            return;
430        }
431        self.super_terminator(term, location);
432    }
433
434    #[tracing::instrument(level = "trace", skip(self, data), ret)]
435    fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) {
436        match data.terminator().kind {
437            TerminatorKind::Return => {
438                let source_info = data.terminator().source_info;
439                // We must assign the value first in case it gets declared dead below
440                self.make_state(
441                    Operand::Move(Place::return_place()),
442                    source_info,
443                    true,
444                    &mut data.statements,
445                );
446                // Return state.
447                let state = VariantIdx::new(CoroutineArgs::RETURNED);
448                data.statements.push(self.set_discr(state, source_info));
449                data.terminator_mut().kind = TerminatorKind::Return;
450            }
451            TerminatorKind::Yield { ref value, resume, mut resume_arg, drop } => {
452                let source_info = data.terminator().source_info;
453                // We must assign the value first in case it gets declared dead below
454                self.make_state(value.clone(), source_info, false, &mut data.statements);
455                // Yield state.
456                let state = CoroutineArgs::RESERVED_VARIANTS + self.suspension_points.len();
457
458                // The resume arg target location might itself be remapped if its base local is
459                // live across a yield.
460                if let Some(&Some((ty, variant, idx))) = self.remap.get(resume_arg.local) {
461                    replace_base(&mut resume_arg, self.make_field(variant, idx, ty), self.tcx);
462                }
463
464                let storage_liveness: GrowableBitSet<Local> =
465                    self.storage_liveness[block].clone().unwrap().into();
466
467                for i in 0..self.always_live_locals.domain_size() {
468                    let l = Local::new(i);
469                    let needs_storage_dead = storage_liveness.contains(l)
470                        && !self.remap.contains(l)
471                        && !self.always_live_locals.contains(l);
472                    if needs_storage_dead {
473                        data.statements
474                            .push(Statement::new(source_info, StatementKind::StorageDead(l)));
475                    }
476                }
477
478                self.suspension_points.push(SuspensionPoint {
479                    state,
480                    resume,
481                    resume_arg,
482                    drop,
483                    storage_liveness,
484                });
485
486                let state = VariantIdx::new(state);
487                data.statements.push(self.set_discr(state, source_info));
488                data.terminator_mut().kind = TerminatorKind::Return;
489            }
490            _ => {}
491        }
492
493        self.super_basic_block_data(block, data);
494    }
495}
496
497fn make_aggregate_adt<'tcx>(
498    def_id: DefId,
499    variant_idx: VariantIdx,
500    args: GenericArgsRef<'tcx>,
501    operands: IndexVec<FieldIdx, Operand<'tcx>>,
502) -> Rvalue<'tcx> {
503    Rvalue::Aggregate(Box::new(AggregateKind::Adt(def_id, variant_idx, args, None, None)), operands)
504}
505
506#[tracing::instrument(level = "trace", skip(tcx, body))]
507fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
508    let coroutine_ty = body.local_decls[SELF_ARG].ty;
509
510    let ref_coroutine_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty);
511
512    // Replace the by value coroutine argument
513    body.local_decls[SELF_ARG].ty = ref_coroutine_ty;
514
515    // Add a deref to accesses of the coroutine state
516    SelfArgVisitor::new(tcx, tcx.mk_place_deref(SELF_ARG.into())).visit_body(body);
517}
518
519#[tracing::instrument(level = "trace", skip(tcx, body))]
520fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
521    let coroutine_ty = body.local_decls[SELF_ARG].ty;
522
523    let ref_coroutine_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty);
524
525    let pin_did = tcx.require_lang_item(LangItem::Pin, body.span);
526    let pin_adt_ref = tcx.adt_def(pin_did);
527    let args = tcx.mk_args(&[ref_coroutine_ty.into()]);
528    let pin_ref_coroutine_ty = Ty::new_adt(tcx, pin_adt_ref, args);
529
530    // Replace the by ref coroutine argument
531    body.local_decls[SELF_ARG].ty = pin_ref_coroutine_ty;
532
533    let unpinned_local = body.local_decls.push(LocalDecl::new(ref_coroutine_ty, body.span));
534
535    // Add the Pin field access to accesses of the coroutine state
536    SelfArgVisitor::new(tcx, tcx.mk_place_deref(unpinned_local.into())).visit_body(body);
537
538    let source_info = SourceInfo::outermost(body.span);
539    let pin_field = tcx.mk_place_field(SELF_ARG.into(), FieldIdx::ZERO, ref_coroutine_ty);
540
541    let statements = &mut body.basic_blocks.as_mut_preserves_cfg()[START_BLOCK].statements;
542    statements.insert(
543        0,
544        Statement::new(
545            source_info,
546            StatementKind::Assign(Box::new((
547                unpinned_local.into(),
548                Rvalue::Use(Operand::Copy(pin_field), WithRetag::Yes),
549            ))),
550        ),
551    );
552}
553
554/// Transforms the `body` of the coroutine applying the following transforms:
555///
556/// - Eliminates all the `get_context` calls that async lowering created.
557/// - Replace all `Local` `ResumeTy` types with `&mut Context<'_>` (`context_mut_ref`).
558///
559/// The `Local`s that have their types replaced are:
560/// - The `resume` argument itself.
561/// - The argument to `get_context`.
562/// - The yielded value of a `yield`.
563///
564/// The `ResumeTy` hides a `&mut Context<'_>` behind an unsafe raw pointer, and the
565/// `get_context` function is being used to convert that back to a `&mut Context<'_>`.
566///
567/// Ideally the async lowering would not use the `ResumeTy`/`get_context` indirection,
568/// but rather directly use `&mut Context<'_>`, however that would currently
569/// lead to higher-kinded lifetime errors.
570/// See <https://github.com/rust-lang/rust/issues/105501>.
571///
572/// The async lowering step and the type / lifetime inference / checking are
573/// still using the `ResumeTy` indirection for the time being, and that indirection
574/// is removed here. After this transform, the coroutine body only knows about `&mut Context<'_>`.
575#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
576fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
577    let context_mut_ref = Ty::new_task_context(tcx);
578
579    // replace the type of the `resume` argument
580    replace_resume_ty_local(tcx, body, CTX_ARG, context_mut_ref);
581
582    let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, body.span);
583
584    for bb in body.basic_blocks.indices() {
585        let bb_data = &body[bb];
586        if bb_data.is_cleanup {
587            continue;
588        }
589
590        match &bb_data.terminator().kind {
591            TerminatorKind::Call { func, .. } => {
592                let func_ty = func.ty(body, tcx);
593                if let ty::FnDef(def_id, _) = *func_ty.kind()
594                    && def_id == get_context_def_id
595                {
596                    let local = eliminate_get_context_call(&mut body[bb]);
597                    replace_resume_ty_local(tcx, body, local, context_mut_ref);
598                }
599            }
600            TerminatorKind::Yield { resume_arg, .. } => {
601                replace_resume_ty_local(tcx, body, resume_arg.local, context_mut_ref);
602            }
603            _ => {}
604        }
605    }
606}
607
608fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local {
609    let terminator = bb_data.terminator.take().unwrap();
610    let TerminatorKind::Call { args, destination, target, .. } = terminator.kind else {
611        bug!();
612    };
613    let [arg] = *Box::try_from(args).unwrap();
614    let local = arg.node.place().unwrap().local;
615
616    let arg = Rvalue::Use(arg.node, WithRetag::Yes);
617    let assign =
618        Statement::new(terminator.source_info, StatementKind::Assign(Box::new((destination, arg))));
619    bb_data.statements.push(assign);
620    bb_data.terminator = Some(Terminator {
621        source_info: terminator.source_info,
622        kind: TerminatorKind::Goto { target: target.unwrap() },
623    });
624    local
625}
626
627#[cfg_attr(not(debug_assertions), allow(unused))]
628#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
629fn replace_resume_ty_local<'tcx>(
630    tcx: TyCtxt<'tcx>,
631    body: &mut Body<'tcx>,
632    local: Local,
633    context_mut_ref: Ty<'tcx>,
634) {
635    let local_ty = std::mem::replace(&mut body.local_decls[local].ty, context_mut_ref);
636    // We have to replace the `ResumeTy` that is used for type and borrow checking
637    // with `&mut Context<'_>` in MIR.
638    #[cfg(debug_assertions)]
639    {
640        if let ty::Adt(resume_ty_adt, _) = local_ty.kind() {
641            let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, body.span));
642            assert_eq!(*resume_ty_adt, expected_adt);
643        } else {
644            panic!("expected `ResumeTy`, found `{:?}`", local_ty);
645        };
646    }
647}
648
649/// Transforms the `body` of the coroutine applying the following transform:
650///
651/// - Remove the `resume` argument.
652///
653/// Ideally the async lowering would not add the `resume` argument.
654///
655/// The async lowering step and the type / lifetime inference / checking are
656/// still using the `resume` argument for the time being. After this transform,
657/// the coroutine body doesn't have the `resume` argument.
658fn transform_gen_context<'tcx>(body: &mut Body<'tcx>) {
659    // This leaves the local representing the `resume` argument in place,
660    // but turns it into a regular local variable. This is cheaper than
661    // adjusting all local references in the body after removing it.
662    body.arg_count = 1;
663}
664
665/// Replaces the entry point of `body` with a block that switches on the coroutine discriminant and
666/// dispatches to blocks according to `cases`.
667///
668/// After this function, the former entry point of the function will be the last block.
669fn insert_switch<'tcx>(
670    body: &mut Body<'tcx>,
671    cases: Vec<(usize, BasicBlock)>,
672    transform: &TransformVisitor<'tcx>,
673    default_block: BasicBlock,
674) {
675    let (assign, discr) = transform.get_discr(body);
676
677    // MIR validation ensures that no block targets `ENTRY_BLOCK`.
678    #[cfg(debug_assertions)]
679    for bb in body.basic_blocks.iter() {
680        for target in bb.terminator().successors() {
681            assert_ne!(target, START_BLOCK);
682        }
683    }
684
685    // Add the switch as entry block, and put the former entry block at the end.
686    let former_entry = std::mem::replace(
687        &mut body.basic_blocks_mut()[START_BLOCK],
688        BasicBlockData::new_stmts(vec![assign], None, false),
689    );
690    let former_entry = body.basic_blocks_mut().push(former_entry);
691
692    // We may point to `START_BLOCK` in our `cases`, replace it with `former_entry`.
693    let mut switch_targets =
694        SwitchTargets::new(cases.iter().map(|(i, bb)| ((*i) as u128, *bb)), default_block);
695    for bb in switch_targets.all_targets_mut() {
696        if *bb == START_BLOCK {
697            *bb = former_entry;
698        }
699    }
700
701    let switch = TerminatorKind::SwitchInt { discr: Operand::Move(discr), targets: switch_targets };
702    body.basic_blocks_mut()[START_BLOCK].terminator =
703        Some(Terminator { source_info: SourceInfo::outermost(body.span), kind: switch });
704}
705
706fn insert_term_block<'tcx>(body: &mut Body<'tcx>, kind: TerminatorKind<'tcx>) -> BasicBlock {
707    let source_info = SourceInfo::outermost(body.span);
708    body.basic_blocks_mut().push(BasicBlockData::new(Some(Terminator { source_info, kind }), false))
709}
710
711fn return_poll_ready_assign<'tcx>(tcx: TyCtxt<'tcx>, source_info: SourceInfo) -> Statement<'tcx> {
712    // Poll::Ready(())
713    let poll_def_id = tcx.require_lang_item(LangItem::Poll, source_info.span);
714    let args = tcx.mk_args(&[tcx.types.unit.into()]);
715    let val = Operand::Constant(Box::new(ConstOperand {
716        span: source_info.span,
717        user_ty: None,
718        const_: Const::zero_sized(tcx.types.unit),
719    }));
720    let ready_val = Rvalue::Aggregate(
721        Box::new(AggregateKind::Adt(poll_def_id, VariantIdx::from_usize(0), args, None, None)),
722        indexvec![val],
723    );
724    Statement::new(source_info, StatementKind::Assign(Box::new((Place::return_place(), ready_val))))
725}
726
727fn insert_poll_ready_block<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> BasicBlock {
728    let source_info = SourceInfo::outermost(body.span);
729    body.basic_blocks_mut().push(BasicBlockData::new_stmts(
730        [return_poll_ready_assign(tcx, source_info)].to_vec(),
731        Some(Terminator { source_info, kind: TerminatorKind::Return }),
732        false,
733    ))
734}
735
736fn insert_panic_block<'tcx>(
737    tcx: TyCtxt<'tcx>,
738    body: &mut Body<'tcx>,
739    message: AssertMessage<'tcx>,
740) -> BasicBlock {
741    let assert_block = body.basic_blocks.next_index();
742    let kind = TerminatorKind::Assert {
743        cond: Operand::Constant(Box::new(ConstOperand {
744            span: body.span,
745            user_ty: None,
746            const_: Const::from_bool(tcx, false),
747        })),
748        expected: true,
749        msg: Box::new(message),
750        target: assert_block,
751        unwind: UnwindAction::Continue,
752    };
753
754    insert_term_block(body, kind)
755}
756
757fn can_return<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, typing_env: ty::TypingEnv<'tcx>) -> bool {
758    // Returning from a function with an uninhabited return type is undefined behavior.
759    if body.return_ty().is_privately_uninhabited(tcx, typing_env) {
760        return false;
761    }
762
763    // If there's a return terminator the function may return.
764    body.basic_blocks.iter().any(|block| matches!(block.terminator().kind, TerminatorKind::Return))
765    // Otherwise the function can't return.
766}
767
768fn can_unwind<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> bool {
769    // Nothing can unwind when landing pads are off.
770    if !tcx.sess.panic_strategy().unwinds() {
771        return false;
772    }
773
774    // If we don't find an unwinding terminator, the function cannot unwind.
775    body.basic_blocks.iter().any(|block| block.terminator().unwind().is_some())
776}
777
778// Poison the coroutine when it unwinds
779fn generate_poison_block_and_redirect_unwinds_there<'tcx>(
780    transform: &TransformVisitor<'tcx>,
781    body: &mut Body<'tcx>,
782) {
783    let source_info = SourceInfo::outermost(body.span);
784    let poison_block = body.basic_blocks_mut().push(BasicBlockData::new_stmts(
785        vec![transform.set_discr(VariantIdx::new(CoroutineArgs::POISONED), source_info)],
786        Some(Terminator { source_info, kind: TerminatorKind::UnwindResume }),
787        true,
788    ));
789
790    for (idx, block) in body.basic_blocks_mut().iter_enumerated_mut() {
791        let source_info = block.terminator().source_info;
792
793        if let TerminatorKind::UnwindResume = block.terminator().kind {
794            // An existing `Resume` terminator is redirected to jump to our dedicated
795            // "poisoning block" above.
796            if idx != poison_block {
797                *block.terminator_mut() =
798                    Terminator { source_info, kind: TerminatorKind::Goto { target: poison_block } };
799            }
800        } else if !block.is_cleanup
801            // Any terminators that *can* unwind but don't have an unwind target set are also
802            // pointed at our poisoning block (unless they're part of the cleanup path).
803            && let Some(unwind @ UnwindAction::Continue) = block.terminator_mut().unwind_mut()
804        {
805            *unwind = UnwindAction::Cleanup(poison_block);
806        }
807    }
808}
809
/// Turns `body` into the coroutine's resume function (e.g. `Coroutine::resume` or
/// `Future::poll`).
///
/// A `SwitchInt` on the coroutine state becomes the new entry block (see `insert_switch`).
/// It dispatches as follows:
/// - `UNRESUMED` goes to the original entry code.
/// - Each suspension point's state goes to a block that re-establishes storage and the
///   resume argument and then continues (see `create_cases`).
/// - `POISONED` goes to a "resumed after panic" assertion. This arm exists only if
///   `can_unwind`.
/// - `RETURNED` exists only if `can_return`. It goes to one of:
///   - a "resumed after return" assertion,
///   - a `Poll::Ready(())` return, for `async_drop_in_place` coroutines,
///   - the block built by `insert_none_ret_block`, for `gen` / `async gen`.
/// - Any other value goes to an `Unreachable` block.
///
/// Afterwards the pass does the following:
/// - Makes the state argument pinned, or a plain indirection for `gen`.
/// - Removes dead blocks.
/// - Runs `AbortUnwindingCalls`.
/// - Re-runs `deref_finder`.
/// - Dumps the result as `coroutine_resume`.
///
/// Note: the order in which blocks are appended here determines their indices.
#[tracing::instrument(level = "trace", skip(tcx, transform, body))]
fn create_coroutine_resume_function<'tcx>(
    tcx: TyCtxt<'tcx>,
    transform: TransformVisitor<'tcx>,
    body: &mut Body<'tcx>,
    can_return: bool,
    can_unwind: bool,
) {
    // Poison the coroutine when it unwinds
    if can_unwind {
        generate_poison_block_and_redirect_unwinds_there(&transform, body);
    }

    // One case per suspension point that has a resume target.
    let mut cases = create_cases(body, &transform, Operation::Resume);

    use rustc_middle::mir::AssertKind::{ResumedAfterPanic, ResumedAfterReturn};

    // Jump to the entry point on the unresumed
    // (`insert_switch` retargets `START_BLOCK` to the relocated former entry block.)
    cases.insert(0, (CoroutineArgs::UNRESUMED, START_BLOCK));

    // Panic when resumed on the returned or poisoned state
    // The poisoned state is only reachable if something can unwind.
    if can_unwind {
        cases.insert(
            1,
            (
                CoroutineArgs::POISONED,
                insert_panic_block(tcx, body, ResumedAfterPanic(transform.coroutine_kind)),
            ),
        );
    }

    // The returned state is only reachable if the body can return.
    if can_return {
        let block = match transform.coroutine_kind {
            CoroutineKind::Desugared(CoroutineDesugaring::Async, _)
            | CoroutineKind::Coroutine(_) => {
                // For `async_drop_in_place<T>::{closure}` we just keep return Poll::Ready,
                // because async drop of such coroutine keeps polling original coroutine
                if tcx.is_async_drop_in_place_coroutine(body.source.def_id()) {
                    insert_poll_ready_block(tcx, body)
                } else {
                    insert_panic_block(tcx, body, ResumedAfterReturn(transform.coroutine_kind))
                }
            }
            // NOTE(review): presumably returns a "no more items" value
            // (see `insert_none_ret_block`); confirm.
            CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)
            | CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
                transform.insert_none_ret_block(body)
            }
        };
        cases.insert(1, (CoroutineArgs::RETURNED, block));
    }

    // Any state value not covered above cannot occur.
    let default_block = insert_term_block(body, TerminatorKind::Unreachable);
    insert_switch(body, cases, &transform, default_block);

    match transform.coroutine_kind {
        CoroutineKind::Coroutine(_)
        | CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _) =>
        {
            make_coroutine_state_argument_pinned(tcx, body);
        }
        // Iterator::next doesn't accept a pinned argument,
        // unlike for all other coroutine kinds.
        CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
            make_coroutine_state_argument_indirect(tcx, body);
        }
    }

    // Make sure we remove dead blocks to remove
    // unrelated code from the drop part of the function
    simplify::remove_dead_blocks(body);

    pm::run_passes_no_validate(tcx, body, &[&abort_unwinding_calls::AbortUnwindingCalls], None);

    // Run derefer to fix Derefs that are not in the first place
    deref_finder(tcx, body, false);

    if let Some(dumper) = MirDumper::new(tcx, "coroutine_resume", body) {
        dumper.dump_mir(body);
    }
}
890
891/// An operation that can be performed on a coroutine.
892#[derive(PartialEq, Copy, Clone, Debug)]
893enum Operation {
894    Resume,
895    Drop,
896    AsyncDrop,
897}
898
899impl Operation {
900    fn target_block(self, point: &SuspensionPoint<'_>) -> Option<BasicBlock> {
901        match self {
902            Operation::Resume => Some(point.resume),
903            Operation::Drop | Operation::AsyncDrop => point.drop,
904        }
905    }
906
907    fn resume_place<'tcx>(self, point: &SuspensionPoint<'tcx>) -> Option<Place<'tcx>> {
908        match self {
909            Operation::Resume | Operation::AsyncDrop => Some(point.resume_arg),
910            Operation::Drop => None,
911        }
912    }
913}
914
915#[tracing::instrument(level = "trace", skip(transform, body))]
916fn create_cases<'tcx>(
917    body: &mut Body<'tcx>,
918    transform: &TransformVisitor<'tcx>,
919    operation: Operation,
920) -> Vec<(usize, BasicBlock)> {
921    let source_info = SourceInfo::outermost(body.span);
922
923    transform
924        .suspension_points
925        .iter()
926        .filter_map(|point| {
927            // Find the target for this suspension point, if applicable
928            operation.target_block(point).map(|target| {
929                let mut statements = Vec::new();
930
931                // Create StorageLive instructions for locals with live storage
932                for l in body.local_decls.indices() {
933                    let needs_storage_live = point.storage_liveness.contains(l)
934                        && !transform.remap.contains(l)
935                        && !transform.always_live_locals.contains(l);
936                    if needs_storage_live {
937                        statements.push(Statement::new(source_info, StatementKind::StorageLive(l)));
938                    }
939                }
940
941                // Move the resume argument to the destination place of the `Yield` terminator
942                if let Some(resume_arg) = operation.resume_place(point)
943                    && resume_arg != CTX_ARG.into()
944                {
945                    statements.push(Statement::new(
946                        source_info,
947                        StatementKind::Assign(Box::new((
948                            resume_arg,
949                            Rvalue::Use(Operand::Move(CTX_ARG.into()), WithRetag::Yes),
950                        ))),
951                    ));
952                }
953
954                // Then jump to the real target
955                let block = body.basic_blocks_mut().push(BasicBlockData::new_stmts(
956                    statements,
957                    Some(Terminator { source_info, kind: TerminatorKind::Goto { target } }),
958                    false,
959                ));
960
961                (point.state, block)
962            })
963        })
964        .collect()
965}
966
impl<'tcx> crate::MirPass<'tcx> for StateTransform {
    /// Rewrites a coroutine body into its resume function. It also stores the coroutine
    /// layout and the drop shim(s) derived from the body on `body.coroutine`. Bodies
    /// without a yield type (i.e. non-coroutines) are left untouched.
    ///
    /// In order, this:
    /// 1. Computes the resume function's return type. This is `Poll`, `Option`,
    ///    `CoroutineState`, or the unchanged yield type for `async gen`.
    /// 2. For async bodies, replaces `ResumeTy` with `&mut Context<'_>`.
    /// 3. Computes the locals live across suspension points and the coroutine layout.
    /// 4. Runs `TransformVisitor` to turn those locals into coroutine fields and to rewrite
    ///    `return` / `yield`.
    /// 5. Saves incoming arguments into the coroutine state where needed, and shrinks
    ///    the signature to `(self, resume arg)`.
    /// 6. Inserts and elaborates the drop used for the unresumed state.
    /// 7. Builds the drop shim(s):
    ///    - an async shim if the body has async drops;
    ///    - otherwise a sync shim plus an async proxy for it.
    /// 8. Builds the resume function itself.
    #[instrument(level = "debug", skip(self, tcx, body), ret)]
    fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
        debug!(def_id = ?body.source.def_id());

        let Some(old_yield_ty) = body.yield_ty() else {
            // This only applies to coroutines
            return;
        };
        tracing::trace!(def_id = ?body.source.def_id());

        let old_ret_ty = body.return_ty();

        // This pass is what populates the drop shims, so none may exist yet.
        assert!(body.coroutine_drop().is_none() && body.coroutine_drop_async().is_none());

        if let Some(dumper) = MirDumper::new(tcx, "coroutine_before", body) {
            dumper.dump_mir(body);
        }

        // The first argument is the coroutine type passed by value
        let coroutine_ty = body.local_decls.raw[1].ty;
        let coroutine_kind = body.coroutine_kind().unwrap();

        // Get the discriminant type and args which typeck computed
        let ty::Coroutine(_, args) = coroutine_ty.kind() else {
            tcx.dcx().span_bug(body.span, format!("unexpected coroutine type {coroutine_ty}"));
        };
        let discr_ty = args.as_coroutine().discr_ty(tcx);

        // Type of the value the resume function returns, per coroutine kind.
        let new_ret_ty = match coroutine_kind {
            CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
                // Compute Poll<return_ty>
                let poll_did = tcx.require_lang_item(LangItem::Poll, body.span);
                let poll_adt_ref = tcx.adt_def(poll_did);
                let poll_args = tcx.mk_args(&[old_ret_ty.into()]);
                Ty::new_adt(tcx, poll_adt_ref, poll_args)
            }
            CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
                // Compute Option<yield_ty>
                let option_did = tcx.require_lang_item(LangItem::Option, body.span);
                let option_adt_ref = tcx.adt_def(option_did);
                let option_args = tcx.mk_args(&[old_yield_ty.into()]);
                Ty::new_adt(tcx, option_adt_ref, option_args)
            }
            CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
                // The yield ty is already `Poll<Option<yield_ty>>`
                old_yield_ty
            }
            CoroutineKind::Coroutine(_) => {
                // Compute CoroutineState<yield_ty, return_ty>
                let state_did = tcx.require_lang_item(LangItem::CoroutineState, body.span);
                let state_adt_ref = tcx.adt_def(state_did);
                let state_args = tcx.mk_args(&[old_yield_ty.into(), old_ret_ty.into()]);
                Ty::new_adt(tcx, state_adt_ref, state_args)
            }
        };

        // We need to insert clean drop for unresumed state and perform drop elaboration
        // (finally in open_drop_for_tuple) before async drop expansion.
        // Async drops, produced by this drop elaboration, will be expanded,
        // and corresponding futures kept in layout.
        let coroutine_is_async = coroutine_kind.is_async_desugaring();
        // `has_async_drops` is evaluated before the async-context rewrite below; it decides
        // both the clean-drop shape and which drop shim gets built.
        let has_async_drops = has_async_drops(body);

        // Replace all occurrences of `ResumeTy` with `&mut Context<'_>` within async bodies.
        if coroutine_is_async {
            transform_async_context(tcx, body);
        }

        let always_live_locals = always_storage_live_locals(body);
        let movable = coroutine_kind.movability() == hir::Movability::Movable;
        let liveness_info =
            locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);

        // Optional self-check (only with `-Zvalidate-mir`): must run before
        // `TransformVisitor` rewrites the body.
        if tcx.sess.opts.unstable_opts.validate_mir {
            let mut vis = EnsureCoroutineFieldAssignmentsNeverAlias {
                assigned_local: None,
                saved_locals: &liveness_info.saved_locals,
                storage_conflicts: &liveness_info.storage_conflicts,
            };

            vis.visit_body(body);
        }

        // Extract locals which are live across suspension point into `layout`
        // `remap` gives a mapping from local indices onto coroutine struct indices
        // `storage_liveness` tells us which locals have live storage at suspension points
        let (remap, layout, storage_liveness) = compute_layout(liveness_info, body);

        // NOTE(review): computed before `TransformVisitor` rewrites returns; presumably
        // deliberate, so confirm before reordering.
        let can_return = can_return(tcx, body, body.typing_env(tcx));

        // We rename RETURN_PLACE which has type mir.return_ty to new_ret_local
        // RETURN_PLACE then is a fresh unused local with type ret_ty.
        let new_ret_local = body.local_decls.push(LocalDecl::new(new_ret_ty, body.span));
        tracing::trace!(?new_ret_local);

        // Run the transformation which converts Places from Local to coroutine struct
        // accesses for locals in `remap`.
        // It also rewrites `return x` and `yield y` as writing a new coroutine state and returning
        // either `CoroutineState::Complete(x)` and `CoroutineState::Yielded(y)`,
        // or `Poll::Ready(x)` and `Poll::Pending` respectively depending on the coroutine kind.
        let mut transform = TransformVisitor {
            tcx,
            coroutine_kind,
            remap,
            storage_liveness,
            always_live_locals,
            suspension_points: Vec::new(),
            discr_ty,
            new_ret_local,
            old_ret_ty,
            old_yield_ty,
        };
        transform.visit_body(body);

        // Swap the actual `RETURN_PLACE` and the provisional `new_ret_local`.
        transform.replace_local(RETURN_PLACE, new_ret_local, body);

        // MIR parameters are not explicitly assigned-to when entering the MIR body.
        // If we want to save their values inside the coroutine state, we need to do so explicitly.
        // Only arguments that `remap` assigns a coroutine field get a store.
        let source_info = SourceInfo::outermost(body.span);
        let args_iter = body.args_iter();
        body.basic_blocks.as_mut()[START_BLOCK].statements.splice(
            0..0,
            args_iter.filter_map(|local| {
                let (ty, variant_index, idx) = transform.remap[local]?;
                let lhs = transform.make_field(variant_index, idx, ty);
                let rhs = Rvalue::Use(Operand::Move(local.into()), WithRetag::Yes);
                let assign = StatementKind::Assign(Box::new((lhs, rhs)));
                Some(Statement::new(source_info, assign))
            }),
        );

        // Update our MIR struct to reflect the changes we've made
        body.arg_count = 2; // self, resume arg
        body.spread_arg = None;

        // Remove the context argument within generator bodies.
        if matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) {
            transform_gen_context(body);
        }

        // The original arguments to the function are no longer arguments, mark them as such.
        // Otherwise they'll conflict with our new arguments, which although they don't have
        // argument_index set, will get emitted as unnamed arguments.
        for var in &mut body.var_debug_info {
            var.argument_index = None;
        }

        body.coroutine.as_mut().unwrap().yield_ty = None;
        body.coroutine.as_mut().unwrap().resume_ty = None;
        body.coroutine.as_mut().unwrap().coroutine_layout = Some(layout);

        // Insert `drop(coroutine_struct)` which is used to drop upvars for coroutines in
        // the unresumed state.
        // This is expanded to a drop ladder in `elaborate_coroutine_drops`.
        let drop_clean = insert_clean_drop(tcx, body, has_async_drops);

        if let Some(dumper) = MirDumper::new(tcx, "coroutine_pre-elab", body) {
            dumper.dump_mir(body);
        }

        // Expand `drop(coroutine_struct)` to a drop ladder which destroys upvars.
        // If any upvars are moved out of, drop elaboration will handle upvar destruction.
        // However we need to also elaborate the code generated by `insert_clean_drop`.
        elaborate_coroutine_drops(tcx, body);

        if let Some(dumper) = MirDumper::new(tcx, "coroutine_post-transform", body) {
            dumper.dump_mir(body);
        }

        // Computed after drop elaboration, presumably so the drops it adds are considered.
        let can_unwind = can_unwind(tcx, body);

        // Create a copy of our MIR and use it to create the drop shim for the coroutine
        if has_async_drops {
            // If coroutine has async drops, generating async drop shim
            let drop_shim =
                create_coroutine_drop_shim_async(tcx, &transform, body, drop_clean, can_unwind);
            body.coroutine.as_mut().unwrap().coroutine_drop_async = Some(drop_shim);
        } else {
            // If coroutine has no async drops, generating sync drop shim
            let drop_shim =
                create_coroutine_drop_shim(tcx, &transform, coroutine_ty, body, drop_clean);
            body.coroutine.as_mut().unwrap().coroutine_drop = Some(drop_shim);

            // For coroutine with sync drop, generating async proxy for `future_drop_poll` call
            let proxy_shim = create_coroutine_drop_shim_proxy_async(tcx, body);
            body.coroutine.as_mut().unwrap().coroutine_drop_proxy_async = Some(proxy_shim);
        }

        // Create the Coroutine::resume / Future::poll function
        // This consumes `transform` and must come last: it rewrites `body` in place.
        create_coroutine_resume_function(tcx, transform, body, can_return, can_unwind);
    }

    /// Marks this pass as required; it returns `true` unconditionally.
    fn is_required(&self) -> bool {
        true
    }
}
1165
/// Looks for any assignments between locals (e.g., `_4 = _5`) that will both be converted to fields
/// in the coroutine state machine but whose storage is not marked as conflicting.
///
/// Validation needs to happen immediately *before* `TransformVisitor` is invoked, not after.
///
/// Such an assignment could arise when it is both the last use of `_5` and the initial
/// definition of `_4`, if we weren't extra careful to mark all locals used inside a statement
/// as conflicting. Non-conflicting coroutine saved locals may be stored at the same location
/// within the coroutine state machine. That would result in ill-formed MIR, because the
/// left-hand and right-hand sides of an assignment may not alias. This caused a
/// miscompilation in [#73137].
///
/// [#73137]: https://github.com/rust-lang/rust/issues/73137
struct EnsureCoroutineFieldAssignmentsNeverAlias<'a> {
    /// Maps a MIR local to the coroutine saved local it is stored in, if any.
    saved_locals: &'a CoroutineSavedLocals,
    /// Pairs of saved locals whose storage conflicts, i.e. which must not share a slot.
    storage_conflicts: &'a BitMatrix<CoroutineSavedLocal, CoroutineSavedLocal>,
    /// Saved local on the left-hand side of the assignment currently being checked, or
    /// `None` when no assignment is being checked.
    assigned_local: Option<CoroutineSavedLocal>,
}
1183
1184impl EnsureCoroutineFieldAssignmentsNeverAlias<'_> {
1185    fn saved_local_for_direct_place(&self, place: Place<'_>) -> Option<CoroutineSavedLocal> {
1186        if place.is_indirect() {
1187            return None;
1188        }
1189
1190        self.saved_locals.get(place.local)
1191    }
1192
1193    fn check_assigned_place(&mut self, place: Place<'_>, f: impl FnOnce(&mut Self)) {
1194        if let Some(assigned_local) = self.saved_local_for_direct_place(place) {
1195            assert!(self.assigned_local.is_none(), "`check_assigned_place` must not recurse");
1196
1197            self.assigned_local = Some(assigned_local);
1198            f(self);
1199            self.assigned_local = None;
1200        }
1201    }
1202}
1203
impl<'tcx> Visitor<'tcx> for EnsureCoroutineFieldAssignmentsNeverAlias<'_> {
    /// Checks one place read by the input side of an assignment that is currently being
    /// checked. If that place is directly a saved local whose storage does not conflict
    /// with the assigned saved local, this is a compiler bug, because the two might share
    /// a slot.
    fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
        let Some(lhs) = self.assigned_local else {
            // This visitor only invokes `visit_place` for the right-hand side of an assignment
            // and only after setting `self.assigned_local`. However, the default impl of
            // `Visitor::super_body` may call `visit_place` with a `NonUseContext` for places
            // with debuginfo. Ignore them here.
            assert!(!context.is_use());
            return;
        };

        let Some(rhs) = self.saved_local_for_direct_place(*place) else { return };

        if !self.storage_conflicts.contains(lhs, rhs) {
            bug!(
                "Assignment between coroutine saved locals whose storage is not \
                    marked as conflicting: {:?}: {:?} = {:?}",
                location,
                lhs,
                rhs,
            );
        }
    }

    /// Checks the right-hand side of every assignment against its destination. All other
    /// statement kinds are ignored. The match is exhaustive on purpose, so that a newly
    /// added statement kind forces a decision here.
    fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
        match &statement.kind {
            StatementKind::Assign((lhs, rhs)) => {
                self.check_assigned_place(*lhs, |this| this.visit_rvalue(rhs, location));
            }

            StatementKind::FakeRead(..)
            | StatementKind::SetDiscriminant { .. }
            | StatementKind::StorageLive(_)
            | StatementKind::StorageDead(_)
            | StatementKind::AscribeUserType(..)
            | StatementKind::PlaceMention(..)
            | StatementKind::Coverage(..)
            | StatementKind::Intrinsic(..)
            | StatementKind::ConstEvalCounter
            | StatementKind::BackwardIncompatibleDropHint { .. }
            | StatementKind::Nop => {}
        }
    }

    /// Treats terminators that write a destination place like assignments:
    /// - a returning `Call` is checked against the callee and the argument operands;
    /// - a `Yield` is checked against its yielded value.
    fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
        // Checking for aliasing in terminators is probably overkill, but until we have actual
        // semantics, we should be conservative here.
        match &terminator.kind {
            // Only calls with a return target are checked; calls with `target: None`
            // fall through to the catch-all arm below (presumably because they never
            // write their destination).
            TerminatorKind::Call {
                func,
                args,
                destination,
                target: Some(_),
                unwind: _,
                call_source: _,
                fn_span: _,
            } => {
                self.check_assigned_place(*destination, |this| {
                    this.visit_operand(func, location);
                    for arg in args {
                        this.visit_operand(&arg.node, location);
                    }
                });
            }

            TerminatorKind::Yield { value, resume: _, resume_arg, drop: _ } => {
                self.check_assigned_place(*resume_arg, |this| this.visit_operand(value, location));
            }

            // FIXME: Does `asm!` have any aliasing requirements?
            TerminatorKind::InlineAsm { .. } => {}

            TerminatorKind::Call { .. }
            | TerminatorKind::Goto { .. }
            | TerminatorKind::SwitchInt { .. }
            | TerminatorKind::UnwindResume
            | TerminatorKind::UnwindTerminate(_)
            | TerminatorKind::Return
            | TerminatorKind::TailCall { .. }
            | TerminatorKind::Unreachable
            | TerminatorKind::Drop { .. }
            | TerminatorKind::Assert { .. }
            | TerminatorKind::CoroutineDrop
            | TerminatorKind::FalseEdge { .. }
            | TerminatorKind::FalseUnwind { .. } => {}
        }
    }
}