rustc_mir_transform/coroutine/
drop.rs

//! Drop and async-drop logic for the coroutine transformation pass.
2
3use super::*;
4
5// Fix return Poll<Rv>::Pending statement into Poll<()>::Pending for async drop function
6struct FixReturnPendingVisitor<'tcx> {
7    tcx: TyCtxt<'tcx>,
8}
9
10impl<'tcx> MutVisitor<'tcx> for FixReturnPendingVisitor<'tcx> {
11    fn tcx(&self) -> TyCtxt<'tcx> {
12        self.tcx
13    }
14
15    fn visit_assign(
16        &mut self,
17        place: &mut Place<'tcx>,
18        rvalue: &mut Rvalue<'tcx>,
19        _location: Location,
20    ) {
21        if place.local != RETURN_PLACE {
22            return;
23        }
24
25        // Converting `_0 = Poll::<Rv>::Pending` to `_0 = Poll::<()>::Pending`
26        if let Rvalue::Aggregate(kind, _) = rvalue
27            && let AggregateKind::Adt(_, _, ref mut args, _, _) = **kind
28        {
29            *args = self.tcx.mk_args(&[self.tcx.types.unit.into()]);
30        }
31    }
32}
33
34// rv = call fut.poll()
35fn build_poll_call<'tcx>(
36    tcx: TyCtxt<'tcx>,
37    body: &mut Body<'tcx>,
38    poll_unit_place: &Place<'tcx>,
39    switch_block: BasicBlock,
40    fut_pin_place: &Place<'tcx>,
41    fut_ty: Ty<'tcx>,
42    context_ref_place: &Place<'tcx>,
43    unwind: UnwindAction,
44) -> BasicBlock {
45    let poll_fn = tcx.require_lang_item(LangItem::FuturePoll, DUMMY_SP);
46    let poll_fn = Ty::new_fn_def(tcx, poll_fn, [fut_ty]);
47    let poll_fn = Operand::Constant(Box::new(ConstOperand {
48        span: DUMMY_SP,
49        user_ty: None,
50        const_: Const::zero_sized(poll_fn),
51    }));
52    let call = TerminatorKind::Call {
53        func: poll_fn.clone(),
54        args: [
55            dummy_spanned(Operand::Move(*fut_pin_place)),
56            dummy_spanned(Operand::Move(*context_ref_place)),
57        ]
58        .into(),
59        destination: *poll_unit_place,
60        target: Some(switch_block),
61        unwind,
62        call_source: CallSource::Misc,
63        fn_span: DUMMY_SP,
64    };
65    insert_term_block(body, call)
66}
67
68// pin_fut = Pin::new_unchecked(&mut fut)
69fn build_pin_fut<'tcx>(
70    tcx: TyCtxt<'tcx>,
71    body: &mut Body<'tcx>,
72    fut_place: Place<'tcx>,
73    unwind: UnwindAction,
74) -> (BasicBlock, Place<'tcx>) {
75    let span = body.span;
76    let source_info = SourceInfo::outermost(span);
77    let fut_ty = fut_place.ty(&body.local_decls, tcx).ty;
78    let fut_ref_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, fut_ty);
79    let fut_ref_place = Place::from(body.local_decls.push(LocalDecl::new(fut_ref_ty, span)));
80    let pin_fut_new_unchecked_fn =
81        Ty::new_fn_def(tcx, tcx.require_lang_item(LangItem::PinNewUnchecked, span), [fut_ref_ty]);
82    let fut_pin_ty = pin_fut_new_unchecked_fn.fn_sig(tcx).output().skip_binder();
83    let fut_pin_place = Place::from(body.local_decls.push(LocalDecl::new(fut_pin_ty, span)));
84    let pin_fut_new_unchecked_fn = Operand::Constant(Box::new(ConstOperand {
85        span,
86        user_ty: None,
87        const_: Const::zero_sized(pin_fut_new_unchecked_fn),
88    }));
89
90    let storage_live = Statement::new(source_info, StatementKind::StorageLive(fut_pin_place.local));
91
92    let fut_ref_assign = Statement::new(
93        source_info,
94        StatementKind::Assign(Box::new((
95            fut_ref_place,
96            Rvalue::Ref(
97                tcx.lifetimes.re_erased,
98                BorrowKind::Mut { kind: MutBorrowKind::Default },
99                fut_place,
100            ),
101        ))),
102    );
103
104    // call Pin<FutTy>::new_unchecked(&mut fut)
105    let pin_fut_bb = body.basic_blocks_mut().push(BasicBlockData::new_stmts(
106        [storage_live, fut_ref_assign].to_vec(),
107        Some(Terminator {
108            source_info,
109            kind: TerminatorKind::Call {
110                func: pin_fut_new_unchecked_fn,
111                args: [dummy_spanned(Operand::Move(fut_ref_place))].into(),
112                destination: fut_pin_place,
113                target: None, // will be fixed later
114                unwind,
115                call_source: CallSource::Misc,
116                fn_span: span,
117            },
118        }),
119        false,
120    ));
121    (pin_fut_bb, fut_pin_place)
122}
123
124// Build Poll switch for async drop
125// match rv {
126//     Ready() => ready_block
127//     Pending => yield_block
128//}
129#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
130fn build_poll_switch<'tcx>(
131    tcx: TyCtxt<'tcx>,
132    body: &mut Body<'tcx>,
133    poll_enum: Ty<'tcx>,
134    poll_unit_place: &Place<'tcx>,
135    fut_pin_place: &Place<'tcx>,
136    ready_block: BasicBlock,
137    yield_block: BasicBlock,
138) -> BasicBlock {
139    let poll_enum_adt = poll_enum.ty_adt_def().unwrap();
140
141    let Discr { val: poll_ready_discr, ty: poll_discr_ty } = poll_enum
142        .discriminant_for_variant(
143            tcx,
144            poll_enum_adt
145                .variant_index_with_id(tcx.require_lang_item(LangItem::PollReady, DUMMY_SP)),
146        )
147        .unwrap();
148    let poll_pending_discr = poll_enum
149        .discriminant_for_variant(
150            tcx,
151            poll_enum_adt
152                .variant_index_with_id(tcx.require_lang_item(LangItem::PollPending, DUMMY_SP)),
153        )
154        .unwrap()
155        .val;
156    let source_info = SourceInfo::outermost(body.span);
157    let poll_discr_place =
158        Place::from(body.local_decls.push(LocalDecl::new(poll_discr_ty, source_info.span)));
159    let discr_assign = Statement::new(
160        source_info,
161        StatementKind::Assign(Box::new((poll_discr_place, Rvalue::Discriminant(*poll_unit_place)))),
162    );
163    let storage_dead = Statement::new(source_info, StatementKind::StorageDead(fut_pin_place.local));
164    let unreachable_block = insert_term_block(body, TerminatorKind::Unreachable);
165    body.basic_blocks_mut().push(BasicBlockData::new_stmts(
166        [storage_dead, discr_assign].to_vec(),
167        Some(Terminator {
168            source_info,
169            kind: TerminatorKind::SwitchInt {
170                discr: Operand::Move(poll_discr_place),
171                targets: SwitchTargets::new(
172                    [(poll_ready_discr, ready_block), (poll_pending_discr, yield_block)]
173                        .into_iter(),
174                    unreachable_block,
175                ),
176            },
177        }),
178        false,
179    ))
180}
181
182// Gather blocks, reachable through 'drop' targets of Yield and Drop terminators (chained)
183#[tracing::instrument(level = "trace", skip(body), ret)]
184fn gather_dropline_blocks<'tcx>(body: &mut Body<'tcx>) -> DenseBitSet<BasicBlock> {
185    let mut dropline: DenseBitSet<BasicBlock> = DenseBitSet::new_empty(body.basic_blocks.len());
186    for (bb, data) in traversal::reverse_postorder(body) {
187        if dropline.contains(bb) {
188            data.terminator().successors().for_each(|v| {
189                dropline.insert(v);
190            });
191        } else {
192            match data.terminator().kind {
193                TerminatorKind::Yield { drop: Some(v), .. } => {
194                    dropline.insert(v);
195                }
196                TerminatorKind::Drop { drop: Some(v), .. } => {
197                    dropline.insert(v);
198                }
199                _ => (),
200            }
201        }
202    }
203    dropline
204}
205
206/// Cleanup all async drops (reset to sync)
207pub(super) fn cleanup_async_drops<'tcx>(body: &mut Body<'tcx>) {
208    for block in body.basic_blocks_mut() {
209        if let TerminatorKind::Drop {
210            place: _,
211            target: _,
212            unwind: _,
213            replace: _,
214            ref mut drop,
215            ref mut async_fut,
216        } = block.terminator_mut().kind
217        {
218            if drop.is_some() || async_fut.is_some() {
219                *drop = None;
220                *async_fut = None;
221            }
222        }
223    }
224}
225
226pub(super) fn has_expandable_async_drops<'tcx>(
227    tcx: TyCtxt<'tcx>,
228    body: &mut Body<'tcx>,
229    coroutine_ty: Ty<'tcx>,
230) -> bool {
231    for bb in START_BLOCK..body.basic_blocks.next_index() {
232        // Drops in unwind path (cleanup blocks) are not expanded to async drops, only sync drops in unwind path
233        if body[bb].is_cleanup {
234            continue;
235        }
236        let TerminatorKind::Drop { place, target: _, unwind: _, replace: _, drop: _, async_fut } =
237            body[bb].terminator().kind
238        else {
239            continue;
240        };
241        let place_ty = place.ty(&body.local_decls, tcx).ty;
242        if place_ty == coroutine_ty {
243            continue;
244        }
245        if async_fut.is_none() {
246            continue;
247        }
248        return true;
249    }
250    return false;
251}
252
253/// Expand Drop terminator for async drops into mainline poll-switch and dropline poll-switch
254#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
255pub(super) fn expand_async_drops<'tcx>(
256    tcx: TyCtxt<'tcx>,
257    body: &mut Body<'tcx>,
258    context_mut_ref: Ty<'tcx>,
259    coroutine_kind: hir::CoroutineKind,
260    coroutine_ty: Ty<'tcx>,
261) {
262    let dropline = gather_dropline_blocks(body);
263    // Clean drop and async_fut fields if potentially async drop is not expanded (stays sync)
264    let remove_asyncness = |block: &mut BasicBlockData<'tcx>| {
265        tracing::trace!("remove_asyncness");
266        if let TerminatorKind::Drop {
267            place: _,
268            target: _,
269            unwind: _,
270            replace: _,
271            ref mut drop,
272            ref mut async_fut,
273        } = block.terminator_mut().kind
274        {
275            *drop = None;
276            *async_fut = None;
277        }
278    };
279    for bb in START_BLOCK..body.basic_blocks.next_index() {
280        // Drops in unwind path (cleanup blocks) are not expanded to async drops, only sync drops in unwind path
281        if body[bb].is_cleanup {
282            remove_asyncness(&mut body[bb]);
283            continue;
284        }
285        let TerminatorKind::Drop { place, target, unwind, replace: _, drop, async_fut } =
286            body[bb].terminator().kind
287        else {
288            continue;
289        };
290
291        let place_ty = place.ty(&body.local_decls, tcx).ty;
292        if place_ty == coroutine_ty {
293            remove_asyncness(&mut body[bb]);
294            continue;
295        }
296
297        let Some(fut_local) = async_fut else {
298            remove_asyncness(&mut body[bb]);
299            continue;
300        };
301
302        let is_dropline_bb = dropline.contains(bb);
303
304        if !is_dropline_bb && drop.is_none() {
305            remove_asyncness(&mut body[bb]);
306            continue;
307        }
308
309        let fut_place = Place::from(fut_local);
310        let fut_ty = fut_place.ty(&body.local_decls, tcx).ty;
311
312        // poll-code:
313        // state_call_drop:
314        // #bb_pin: fut_pin = Pin<FutT>::new_unchecked(&mut fut)
315        // #bb_call: rv = call fut.poll() (or future_drop_poll(fut) for internal future drops)
316        // #bb_check: match (rv)
317        //  pending => return rv (yield)
318        //  ready => *continue_bb|drop_bb*
319
320        let source_info = body[bb].terminator.as_ref().unwrap().source_info;
321
322        // Compute Poll<> (aka Poll with void return)
323        let poll_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Poll, source_info.span));
324        let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()]));
325        let poll_decl = LocalDecl::new(poll_enum, source_info.span);
326        let poll_unit_place = Place::from(body.local_decls.push(poll_decl));
327
328        // First state-loop yield for mainline
329        let context_ref_place =
330            Place::from(body.local_decls.push(LocalDecl::new(context_mut_ref, source_info.span)));
331        let arg = Rvalue::Use(Operand::Move(Place::from(CTX_ARG)));
332        body[bb].statements.push(Statement::new(
333            source_info,
334            StatementKind::Assign(Box::new((context_ref_place, arg))),
335        ));
336        let yield_block = insert_term_block(body, TerminatorKind::Unreachable); // `kind` replaced later to yield
337        let (pin_bb, fut_pin_place) =
338            build_pin_fut(tcx, body, fut_place.clone(), UnwindAction::Continue);
339        let switch_block = build_poll_switch(
340            tcx,
341            body,
342            poll_enum,
343            &poll_unit_place,
344            &fut_pin_place,
345            target,
346            yield_block,
347        );
348        let call_bb = build_poll_call(
349            tcx,
350            body,
351            &poll_unit_place,
352            switch_block,
353            &fut_pin_place,
354            fut_ty,
355            &context_ref_place,
356            unwind,
357        );
358
359        // Second state-loop yield for transition to dropline (when coroutine async drop started)
360        let mut dropline_transition_bb: Option<BasicBlock> = None;
361        let mut dropline_yield_bb: Option<BasicBlock> = None;
362        let mut dropline_context_ref: Option<Place<'_>> = None;
363        let mut dropline_call_bb: Option<BasicBlock> = None;
364        if !is_dropline_bb {
365            let context_ref_place2: Place<'_> = Place::from(
366                body.local_decls.push(LocalDecl::new(context_mut_ref, source_info.span)),
367            );
368            let drop_yield_block = insert_term_block(body, TerminatorKind::Unreachable); // `kind` replaced later to yield
369            let (pin_bb2, fut_pin_place2) =
370                build_pin_fut(tcx, body, fut_place, UnwindAction::Continue);
371            let drop_switch_block = build_poll_switch(
372                tcx,
373                body,
374                poll_enum,
375                &poll_unit_place,
376                &fut_pin_place2,
377                drop.unwrap(),
378                drop_yield_block,
379            );
380            let drop_call_bb = build_poll_call(
381                tcx,
382                body,
383                &poll_unit_place,
384                drop_switch_block,
385                &fut_pin_place2,
386                fut_ty,
387                &context_ref_place2,
388                unwind,
389            );
390            dropline_transition_bb = Some(pin_bb2);
391            dropline_yield_bb = Some(drop_yield_block);
392            dropline_context_ref = Some(context_ref_place2);
393            dropline_call_bb = Some(drop_call_bb);
394        }
395
396        let value =
397            if matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _))
398            {
399                // For AsyncGen we need `yield Poll<OptRet>::Pending`
400                let full_yield_ty = body.yield_ty().unwrap();
401                let ty::Adt(_poll_adt, args) = *full_yield_ty.kind() else { bug!() };
402                let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
403                let yield_ty = args.type_at(0);
404                Operand::Constant(Box::new(ConstOperand {
405                    span: source_info.span,
406                    const_: Const::Unevaluated(
407                        UnevaluatedConst::new(
408                            tcx.require_lang_item(LangItem::AsyncGenPending, source_info.span),
409                            tcx.mk_args(&[yield_ty.into()]),
410                        ),
411                        full_yield_ty,
412                    ),
413                    user_ty: None,
414                }))
415            } else {
416                // value needed only for return-yields or gen-coroutines, so just const here
417                Operand::Constant(Box::new(ConstOperand {
418                    span: source_info.span,
419                    user_ty: None,
420                    const_: Const::from_bool(tcx, false),
421                }))
422            };
423
424        use rustc_middle::mir::AssertKind::ResumedAfterDrop;
425        let panic_bb = insert_panic_block(tcx, body, ResumedAfterDrop(coroutine_kind));
426
427        if is_dropline_bb {
428            body[yield_block].terminator_mut().kind = TerminatorKind::Yield {
429                value: value.clone(),
430                resume: panic_bb,
431                resume_arg: context_ref_place,
432                drop: Some(pin_bb),
433            };
434        } else {
435            body[yield_block].terminator_mut().kind = TerminatorKind::Yield {
436                value: value.clone(),
437                resume: pin_bb,
438                resume_arg: context_ref_place,
439                drop: dropline_transition_bb,
440            };
441            body[dropline_yield_bb.unwrap()].terminator_mut().kind = TerminatorKind::Yield {
442                value,
443                resume: panic_bb,
444                resume_arg: dropline_context_ref.unwrap(),
445                drop: dropline_transition_bb,
446            };
447        }
448
449        if let TerminatorKind::Call { ref mut target, .. } = body[pin_bb].terminator_mut().kind {
450            *target = Some(call_bb);
451        } else {
452            bug!()
453        }
454        if !is_dropline_bb {
455            if let TerminatorKind::Call { ref mut target, .. } =
456                body[dropline_transition_bb.unwrap()].terminator_mut().kind
457            {
458                *target = dropline_call_bb;
459            } else {
460                bug!()
461            }
462        }
463
464        body[bb].terminator_mut().kind = TerminatorKind::Goto { target: pin_bb };
465    }
466}
467
/// Elaborates every `Drop` terminator whose place is the coroutine state argument
/// (`SELF_ARG`) into explicit drop code.
///
/// The elaboration uses a `DropShimElaborator` configured with
/// `produce_async_drops: false`. `Drop`s of any other place are left untouched. Each
/// terminator's `drop` (dropline) target is forwarded to `elaborate_drop`. All changes are
/// collected in a `MirPatch` and applied to `body` at the end.
#[tracing::instrument(level = "trace", skip(tcx, body))]
pub(super) fn elaborate_coroutine_drops<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
    use crate::elaborate_drop::{Unwind, elaborate_drop};
    use crate::patch::MirPatch;
    use crate::shim::DropShimElaborator;

    // Note that `elaborate_drops` only drops the upvars of a coroutine, and
    // this is ok because `open_drop` can only be reached within that own
    // coroutine's resume function.
    let typing_env = body.typing_env(tcx);

    let mut elaborator = DropShimElaborator {
        body,
        patch: MirPatch::new(body),
        tcx,
        typing_env,
        produce_async_drops: false,
    };

    for (block, block_data) in body.basic_blocks.iter_enumerated() {
        // Only drops of the coroutine state argument itself are elaborated here.
        let (target, unwind, source_info, dropline) = match block_data.terminator() {
            Terminator {
                source_info,
                kind: TerminatorKind::Drop { place, target, unwind, replace: _, drop, async_fut: _ },
            } => {
                if let Some(local) = place.as_local()
                    && local == SELF_ARG
                {
                    (target, unwind, source_info, *drop)
                } else {
                    continue;
                }
            }
            _ => continue,
        };
        // Inside a cleanup block we are already unwinding. Otherwise the terminator's
        // `UnwindAction` is mapped to a concrete cleanup block obtained from the patch.
        let unwind = if block_data.is_cleanup {
            Unwind::InCleanup
        } else {
            Unwind::To(match *unwind {
                UnwindAction::Cleanup(tgt) => tgt,
                UnwindAction::Continue => elaborator.patch.resume_block(),
                UnwindAction::Unreachable => elaborator.patch.unreachable_cleanup_block(),
                UnwindAction::Terminate(reason) => elaborator.patch.terminate_block(reason),
            })
        };
        elaborate_drop(
            &mut elaborator,
            *source_info,
            Place::from(SELF_ARG),
            (),
            *target,
            unwind,
            block,
            dropline,
        );
    }
    elaborator.patch.apply(body);
}
526
527#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
528pub(super) fn insert_clean_drop<'tcx>(
529    tcx: TyCtxt<'tcx>,
530    body: &mut Body<'tcx>,
531    has_async_drops: bool,
532) -> BasicBlock {
533    let source_info = SourceInfo::outermost(body.span);
534    let return_block = if has_async_drops {
535        insert_poll_ready_block(tcx, body)
536    } else {
537        insert_term_block(body, TerminatorKind::Return)
538    };
539
540    // FIXME: When move insert_clean_drop + elaborate_coroutine_drops before async drops expand,
541    // also set dropline here:
542    // let dropline = if has_async_drops { Some(return_block) } else { None };
543    let dropline = None;
544
545    let term = TerminatorKind::Drop {
546        place: Place::from(SELF_ARG),
547        target: return_block,
548        unwind: UnwindAction::Continue,
549        replace: false,
550        drop: dropline,
551        async_fut: None,
552    };
553
554    // Create a block to destroy an unresumed coroutines. This can only destroy upvars.
555    body.basic_blocks_mut()
556        .push(BasicBlockData::new(Some(Terminator { source_info, kind: term }), false))
557}
558
/// Builds the synchronous drop shim (`drop_in_place` glue) for a coroutine from a clone
/// of its transformed `body`.
///
/// The shim switches on the coroutine state:
/// - the unresumed state goes to `drop_clean`;
/// - the cases from `create_cases(.., Operation::Drop)` are added;
/// - all other states (returned, poisoned) fall through to a block that just returns.
///
/// `CoroutineDrop` terminators become `Return` and the return place is retyped to `()`.
/// The state argument is made indirect and retyped to `*mut coroutine_ty`. Dead blocks are
/// removed, and the body's instance becomes
/// `InstanceKind::DropGlue(drop_in_place, Some(coroutine_ty))`.
#[tracing::instrument(level = "trace", skip(tcx, transform, body))]
pub(super) fn create_coroutine_drop_shim<'tcx>(
    tcx: TyCtxt<'tcx>,
    transform: &TransformVisitor<'tcx>,
    coroutine_ty: Ty<'tcx>,
    body: &Body<'tcx>,
    drop_clean: BasicBlock,
) -> Body<'tcx> {
    let mut body = body.clone();
    // Take the coroutine info out of the body, since the drop shim is
    // not a coroutine body itself; it just has its drop built out of it.
    let _ = body.coroutine.take();
    // Make sure the resume argument is not included here, since we're
    // building a body for `drop_in_place`.
    body.arg_count = 1;

    let source_info = SourceInfo::outermost(body.span);

    let mut cases = create_cases(&mut body, transform, Operation::Drop);

    cases.insert(0, (CoroutineArgs::UNRESUMED, drop_clean));

    // The returned state and the poisoned state fall through to the default
    // case, which just returns.

    let default_block = insert_term_block(&mut body, TerminatorKind::Return);
    insert_switch(&mut body, cases, transform, default_block);

    for block in body.basic_blocks_mut() {
        let kind = &mut block.terminator_mut().kind;
        if let TerminatorKind::CoroutineDrop = *kind {
            *kind = TerminatorKind::Return;
        }
    }

    // Replace the return variable
    body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(tcx.types.unit, source_info);

    make_coroutine_state_argument_indirect(tcx, &mut body);

    // Change the coroutine argument from &mut to *mut
    body.local_decls[SELF_ARG] =
        LocalDecl::with_source_info(Ty::new_mut_ptr(tcx, coroutine_ty), source_info);

    // Remove dead blocks so that unrelated code from the resume part of
    // the function is dropped.
    simplify::remove_dead_blocks(&mut body);

    // Update the body's def to become the drop glue.
    let coroutine_instance = body.source.instance;
    let drop_in_place = tcx.require_lang_item(LangItem::DropInPlace, body.span);
    let drop_instance = InstanceKind::DropGlue(drop_in_place, Some(coroutine_ty));

    // Temporarily change the MirSource to the coroutine's instance so that dump_mir
    // produces a more sensible filename.
    // NOTE(review): `body.source.instance` has not been modified since `coroutine_instance`
    // was read above, so this assignment is currently a no-op.
    body.source.instance = coroutine_instance;
    if let Some(dumper) = MirDumper::new(tcx, "coroutine_drop", &body) {
        dumper.dump_mir(&body);
    }
    body.source.instance = drop_instance;

    // Creating a coroutine drop shim happens on `Analysis(PostCleanup) -> Runtime(Initial)`
    // but the pass manager doesn't update the phase of the coroutine drop shim. Update the
    // phase of the drop shim so that later on when we run the pass manager on the shim, in
    // the `mir_shims` query, we don't ICE on the intra-pass validation before we've updated
    // the phase of the body from analysis.
    body.phase = MirPhase::Runtime(RuntimePhase::Initial);

    body
}
629
/// Builds the async drop shim function that drops the coroutine itself.
///
/// Starting from a clone of `body`, it performs these steps:
/// - retypes `Pending` returns to `Poll<()>` via `FixReturnPendingVisitor`;
/// - if `can_unwind`, redirects unwinds to a poison block and makes the poisoned state
///   panic with `ResumedAfterPanic`;
/// - switches on the coroutine state: unresumed goes to `drop_clean`, the cases from
///   `create_cases(.., Operation::Drop)` are added, and everything else (including the
///   returned state) goes to the block from `insert_poll_ready_block`;
/// - turns `CoroutineDrop` into `Return`, appending the `return_poll_ready_assign`
///   statement;
/// - retypes the return place to `Poll<()>` and makes the state argument indirect (and
///   pinned, except for `gen` coroutines);
/// - removes dead blocks and runs `AbortUnwindingCalls`.
#[tracing::instrument(level = "trace", skip(tcx, transform, body))]
pub(super) fn create_coroutine_drop_shim_async<'tcx>(
    tcx: TyCtxt<'tcx>,
    transform: &TransformVisitor<'tcx>,
    body: &Body<'tcx>,
    drop_clean: BasicBlock,
    can_unwind: bool,
) -> Body<'tcx> {
    let mut body = body.clone();
    // Take the coroutine info out of the body, since the drop shim is
    // not a coroutine body itself; it just has its drop built out of it.
    let _ = body.coroutine.take();

    // Rewrite `_0 = Poll::<Rv>::Pending` into `_0 = Poll::<()>::Pending`.
    FixReturnPendingVisitor { tcx }.visit_body(&mut body);

    // Poison the coroutine when it unwinds
    if can_unwind {
        generate_poison_block_and_redirect_unwinds_there(transform, &mut body);
    }

    let source_info = SourceInfo::outermost(body.span);

    let mut cases = create_cases(&mut body, transform, Operation::Drop);

    cases.insert(0, (CoroutineArgs::UNRESUMED, drop_clean));

    use rustc_middle::mir::AssertKind::ResumedAfterPanic;
    // Panic when resumed in the poisoned state (only reachable if the body can unwind).
    // The returned state is handled by the default block below.
    if can_unwind {
        cases.insert(
            1,
            (
                CoroutineArgs::POISONED,
                insert_panic_block(tcx, &mut body, ResumedAfterPanic(transform.coroutine_kind)),
            ),
        );
    }

    // RETURNED state also goes to default_block with `return Ready<()>`.
    // For fully-polled coroutine, async drop has nothing to do.
    let default_block = insert_poll_ready_block(tcx, &mut body);
    insert_switch(&mut body, cases, transform, default_block);

    for block in body.basic_blocks_mut() {
        let kind = &mut block.terminator_mut().kind;
        if let TerminatorKind::CoroutineDrop = *kind {
            *kind = TerminatorKind::Return;
            block.statements.push(return_poll_ready_assign(tcx, source_info));
        }
    }

    // Replace the return variable: Poll<RetT> to Poll<()>
    let poll_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Poll, body.span));
    let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()]));
    body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(poll_enum, source_info);

    make_coroutine_state_argument_indirect(tcx, &mut body);

    match transform.coroutine_kind {
        // Iterator::next doesn't accept a pinned argument,
        // unlike for all other coroutine kinds.
        CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {}
        _ => {
            make_coroutine_state_argument_pinned(tcx, &mut body);
        }
    }

    // Remove dead blocks so that unrelated code from the resume part of
    // the function is dropped.
    simplify::remove_dead_blocks(&mut body);

    pm::run_passes_no_validate(
        tcx,
        &mut body,
        &[&abort_unwinding_calls::AbortUnwindingCalls],
        None,
    );

    if let Some(dumper) = MirDumper::new(tcx, "coroutine_drop_async", &body) {
        dumper.dump_mir(&body);
    }

    body
}
715
716// Create async drop shim proxy function for future_drop_poll
717// It is just { call coroutine_drop(); return Poll::Ready(); }
718pub(super) fn create_coroutine_drop_shim_proxy_async<'tcx>(
719    tcx: TyCtxt<'tcx>,
720    body: &Body<'tcx>,
721) -> Body<'tcx> {
722    let mut body = body.clone();
723    // Take the coroutine info out of the body, since the drop shim is
724    // not a coroutine body itself; it just has its drop built out of it.
725    let _ = body.coroutine.take();
726    let basic_blocks: IndexVec<BasicBlock, BasicBlockData<'tcx>> = IndexVec::new();
727    body.basic_blocks = BasicBlocks::new(basic_blocks);
728    body.var_debug_info.clear();
729
730    // Keeping return value and args
731    body.local_decls.truncate(1 + body.arg_count);
732
733    let source_info = SourceInfo::outermost(body.span);
734
735    // Replace the return variable: Poll<RetT> to Poll<()>
736    let poll_adt_ref = tcx.adt_def(tcx.require_lang_item(LangItem::Poll, body.span));
737    let poll_enum = Ty::new_adt(tcx, poll_adt_ref, tcx.mk_args(&[tcx.types.unit.into()]));
738    body.local_decls[RETURN_PLACE] = LocalDecl::with_source_info(poll_enum, source_info);
739
740    // call coroutine_drop()
741    let call_bb = body.basic_blocks_mut().push(BasicBlockData::new(None, false));
742
743    // return Poll::Ready()
744    let ret_bb = insert_poll_ready_block(tcx, &mut body);
745
746    let kind = TerminatorKind::Drop {
747        place: Place::from(SELF_ARG),
748        target: ret_bb,
749        unwind: UnwindAction::Continue,
750        replace: false,
751        drop: None,
752        async_fut: None,
753    };
754    body.basic_blocks_mut()[call_bb].terminator = Some(Terminator { source_info, kind });
755
756    if let Some(dumper) = MirDumper::new(tcx, "coroutine_drop_proxy_async", &body) {
757        dumper.dump_mir(&body);
758    }
759
760    body
761}