1mod by_move_body;
54mod drop;
55use std::ops;
56
57pub(super) use by_move_body::coroutine_by_move_body_def_id;
58use drop::{
59 cleanup_async_drops, create_coroutine_drop_shim, create_coroutine_drop_shim_async,
60 create_coroutine_drop_shim_proxy_async, elaborate_coroutine_drops, expand_async_drops,
61 has_expandable_async_drops, insert_clean_drop,
62};
63use itertools::izip;
64use rustc_abi::{FieldIdx, VariantIdx};
65use rustc_data_structures::fx::FxHashSet;
66use rustc_errors::pluralize;
67use rustc_hir::lang_items::LangItem;
68use rustc_hir::{self as hir, CoroutineDesugaring, CoroutineKind, find_attr};
69use rustc_index::bit_set::{BitMatrix, DenseBitSet, GrowableBitSet};
70use rustc_index::{Idx, IndexVec, indexvec};
71use rustc_middle::mir::visit::{MutVisitor, MutatingUseContext, PlaceContext, Visitor};
72use rustc_middle::mir::*;
73use rustc_middle::ty::util::Discr;
74use rustc_middle::ty::{
75 self, CoroutineArgs, CoroutineArgsExt, GenericArgsRef, InstanceKind, Ty, TyCtxt, TypingMode,
76};
77use rustc_middle::{bug, span_bug};
78use rustc_mir_dataflow::impls::{
79 MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive,
80 always_storage_live_locals,
81};
82use rustc_mir_dataflow::{
83 Analysis, Results, ResultsCursor, ResultsVisitor, visit_reachable_results,
84};
85use rustc_span::def_id::{DefId, LocalDefId};
86use rustc_span::source_map::dummy_spanned;
87use rustc_span::{DUMMY_SP, Span};
88use rustc_trait_selection::error_reporting::InferCtxtErrorExt;
89use rustc_trait_selection::infer::TyCtxtInferExt as _;
90use rustc_trait_selection::traits::{ObligationCause, ObligationCauseCode, ObligationCtxt};
91use tracing::{debug, instrument, trace};
92
93use crate::deref_separator::deref_finder;
94use crate::{abort_unwinding_calls, errors, pass_manager as pm, simplify};
95
/// Unit marker type named after the coroutine state transform.
///
/// NOTE(review): presumably the MIR pass implementation for this type lives elsewhere in the
/// file; it is not visible in this part of the source — confirm.
pub(super) struct StateTransform;
97
/// MIR visitor that exchanges every mention of two locals (`from` <-> `to`).
struct RenameLocalVisitor<'tcx> {
    /// First local of the pair; occurrences of it become `to`.
    from: Local,
    /// Second local of the pair; occurrences of it become `from`.
    to: Local,
    /// Type context handed back from `MutVisitor::tcx`.
    tcx: TyCtxt<'tcx>,
}
103
104impl<'tcx> MutVisitor<'tcx> for RenameLocalVisitor<'tcx> {
105 fn tcx(&self) -> TyCtxt<'tcx> {
106 self.tcx
107 }
108
109 fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
110 if *local == self.from {
111 *local = self.to;
112 } else if *local == self.to {
113 *local = self.from;
114 }
115 }
116
117 fn visit_terminator(&mut self, terminator: &mut Terminator<'tcx>, location: Location) {
118 match terminator.kind {
119 TerminatorKind::Return => {
120 }
123 _ => self.super_terminator(terminator, location),
124 }
125 }
126}
127
/// MIR visitor that re-roots every place based on `SELF_ARG` onto `new_base`.
struct SelfArgVisitor<'tcx> {
    /// Type context, needed to intern the rewritten projection lists.
    tcx: TyCtxt<'tcx>,
    /// Place substituted for the bare `SELF_ARG` base of each rewritten place.
    new_base: Place<'tcx>,
}
132
133impl<'tcx> SelfArgVisitor<'tcx> {
134 fn new(tcx: TyCtxt<'tcx>, new_base: Place<'tcx>) -> Self {
135 Self { tcx, new_base }
136 }
137}
138
139impl<'tcx> MutVisitor<'tcx> for SelfArgVisitor<'tcx> {
140 fn tcx(&self) -> TyCtxt<'tcx> {
141 self.tcx
142 }
143
144 fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _: Location) {
145 assert_ne!(*local, SELF_ARG);
146 }
147
148 fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, _: Location) {
149 if place.local == SELF_ARG {
150 replace_base(place, self.new_base, self.tcx);
151 }
152
153 for elem in place.projection.iter() {
154 if let PlaceElem::Index(local) = elem {
155 assert_ne!(local, SELF_ARG);
156 }
157 }
158 }
159}
160
161#[tracing::instrument(level = "trace", skip(tcx))]
162fn replace_base<'tcx>(place: &mut Place<'tcx>, new_base: Place<'tcx>, tcx: TyCtxt<'tcx>) {
163 place.local = new_base.local;
164
165 let mut new_projection = new_base.projection.to_vec();
166 new_projection.append(&mut place.projection.to_vec());
167
168 place.projection = tcx.mk_place_elems(&new_projection);
169 tracing::trace!(?place);
170}
171
/// Local `_1`. This file treats it as the coroutine state argument: its declared type is read
/// and rewritten by the `make_coroutine_state_argument_*` functions.
const SELF_ARG: Local = Local::from_u32(1);
/// Local `_2`. `transform_async_context` retypes it to the task-context type.
const CTX_ARG: Local = Local::from_u32(2);
174
/// Record of one `Yield` terminator, collected by `TransformVisitor::visit_basic_block_data`.
struct SuspensionPoint<'tcx> {
    /// State number assigned to this yield (reserved variants + index of the yield); it is
    /// also the discriminant value stored before the block returns.
    state: usize,
    /// The `resume` target of the original `Yield` terminator.
    resume: BasicBlock,
    /// The `resume_arg` place of the original `Yield`, rebased if its local was remapped.
    resume_arg: Place<'tcx>,
    /// The `drop` target of the original `Yield` terminator, if any.
    drop: Option<BasicBlock>,
    /// Copy of the storage-liveness set recorded for the yielding block.
    storage_liveness: GrowableBitSet<Local>,
}
188
/// State for the visitor that rewrites `Return`/`Yield` terminators and remapped locals.
struct TransformVisitor<'tcx> {
    tcx: TyCtxt<'tcx>,
    /// Selects which wrapper value `make_state` builds (`Poll`, `Option`, `CoroutineState`, ...).
    coroutine_kind: hir::CoroutineKind,

    /// Type of the temporary that `get_discr` creates to hold the state discriminant.
    discr_ty: Ty<'tcx>,

    /// For each remapped local: its type and the (variant, field) of the coroutine state it is
    /// accessed through. `None` for locals that are not remapped.
    remap: IndexVec<Local, Option<(Ty<'tcx>, VariantIdx, FieldIdx)>>,

    /// Per-block storage-liveness sets; must be `Some` for every block ending in `Yield`
    /// (it is unwrapped there).
    storage_liveness: IndexVec<BasicBlock, Option<DenseBitSet<Local>>>,

    /// Suspension points gathered while visiting, in visiting order.
    suspension_points: Vec<SuspensionPoint<'tcx>>,

    /// Locals that never get a `StorageDead` inserted at a yield.
    /// NOTE(review): presumably the locals without storage markers — confirm at construction.
    always_live_locals: DenseBitSet<Local>,

    /// Local that `make_state` assigns the wrapped return/yield value to.
    new_ret_local: Local,

    /// Yield type of the coroutine before the transform; used as a generic argument.
    old_yield_ty: Ty<'tcx>,

    /// Return type of the coroutine before the transform; used as a generic argument.
    old_ret_ty: Ty<'tcx>,
}
215
216impl<'tcx> TransformVisitor<'tcx> {
217 fn insert_none_ret_block(&self, body: &mut Body<'tcx>) -> BasicBlock {
218 let block = body.basic_blocks.next_index();
219 let source_info = SourceInfo::outermost(body.span);
220
221 let none_value = match self.coroutine_kind {
222 CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
223 span_bug!(body.span, "`Future`s are not fused inherently")
224 }
225 CoroutineKind::Coroutine(_) => span_bug!(body.span, "`Coroutine`s cannot be fused"),
226 CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
228 let option_def_id = self.tcx.require_lang_item(LangItem::Option, body.span);
229 make_aggregate_adt(
230 option_def_id,
231 VariantIdx::ZERO,
232 self.tcx.mk_args(&[self.old_yield_ty.into()]),
233 IndexVec::new(),
234 )
235 }
236 CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
238 let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() };
239 let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
240 let yield_ty = args.type_at(0);
241 Rvalue::Use(Operand::Constant(Box::new(ConstOperand {
242 span: source_info.span,
243 const_: Const::Unevaluated(
244 UnevaluatedConst::new(
245 self.tcx.require_lang_item(LangItem::AsyncGenFinished, body.span),
246 self.tcx.mk_args(&[yield_ty.into()]),
247 ),
248 self.old_yield_ty,
249 ),
250 user_ty: None,
251 })))
252 }
253 };
254
255 let statements = vec![Statement::new(
256 source_info,
257 StatementKind::Assign(Box::new((Place::return_place(), none_value))),
258 )];
259
260 body.basic_blocks_mut().push(BasicBlockData::new_stmts(
261 statements,
262 Some(Terminator { source_info, kind: TerminatorKind::Return }),
263 false,
264 ));
265
266 block
267 }
268
269 #[tracing::instrument(level = "trace", skip(self, statements))]
275 fn make_state(
276 &self,
277 val: Operand<'tcx>,
278 source_info: SourceInfo,
279 is_return: bool,
280 statements: &mut Vec<Statement<'tcx>>,
281 ) {
282 const ZERO: VariantIdx = VariantIdx::ZERO;
283 const ONE: VariantIdx = VariantIdx::from_usize(1);
284 let rvalue = match self.coroutine_kind {
285 CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
286 let poll_def_id = self.tcx.require_lang_item(LangItem::Poll, source_info.span);
287 let args = self.tcx.mk_args(&[self.old_ret_ty.into()]);
288 let (variant_idx, operands) = if is_return {
289 (ZERO, indexvec![val]) } else {
291 (ONE, IndexVec::new()) };
293 make_aggregate_adt(poll_def_id, variant_idx, args, operands)
294 }
295 CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
296 let option_def_id = self.tcx.require_lang_item(LangItem::Option, source_info.span);
297 let args = self.tcx.mk_args(&[self.old_yield_ty.into()]);
298 let (variant_idx, operands) = if is_return {
299 (ZERO, IndexVec::new()) } else {
301 (ONE, indexvec![val]) };
303 make_aggregate_adt(option_def_id, variant_idx, args, operands)
304 }
305 CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
306 if is_return {
307 let ty::Adt(_poll_adt, args) = *self.old_yield_ty.kind() else { bug!() };
308 let ty::Adt(_option_adt, args) = *args.type_at(0).kind() else { bug!() };
309 let yield_ty = args.type_at(0);
310 Rvalue::Use(Operand::Constant(Box::new(ConstOperand {
311 span: source_info.span,
312 const_: Const::Unevaluated(
313 UnevaluatedConst::new(
314 self.tcx.require_lang_item(
315 LangItem::AsyncGenFinished,
316 source_info.span,
317 ),
318 self.tcx.mk_args(&[yield_ty.into()]),
319 ),
320 self.old_yield_ty,
321 ),
322 user_ty: None,
323 })))
324 } else {
325 Rvalue::Use(val)
326 }
327 }
328 CoroutineKind::Coroutine(_) => {
329 let coroutine_state_def_id =
330 self.tcx.require_lang_item(LangItem::CoroutineState, source_info.span);
331 let args = self.tcx.mk_args(&[self.old_yield_ty.into(), self.old_ret_ty.into()]);
332 let variant_idx = if is_return {
333 ONE } else {
335 ZERO };
337 make_aggregate_adt(coroutine_state_def_id, variant_idx, args, indexvec![val])
338 }
339 };
340
341 statements.push(Statement::new(
343 source_info,
344 StatementKind::Assign(Box::new((self.new_ret_local.into(), rvalue))),
345 ));
346 }
347
348 #[tracing::instrument(level = "trace", skip(self), ret)]
350 fn make_field(&self, variant_index: VariantIdx, idx: FieldIdx, ty: Ty<'tcx>) -> Place<'tcx> {
351 let self_place = Place::from(SELF_ARG);
352 let base = self.tcx.mk_place_downcast_unnamed(self_place, variant_index);
353 let mut projection = base.projection.to_vec();
354 projection.push(ProjectionElem::Field(idx, ty));
355
356 Place { local: base.local, projection: self.tcx.mk_place_elems(&projection) }
357 }
358
359 #[tracing::instrument(level = "trace", skip(self))]
361 fn set_discr(&self, state_disc: VariantIdx, source_info: SourceInfo) -> Statement<'tcx> {
362 let self_place = Place::from(SELF_ARG);
363 Statement::new(
364 source_info,
365 StatementKind::SetDiscriminant {
366 place: Box::new(self_place),
367 variant_index: state_disc,
368 },
369 )
370 }
371
372 #[tracing::instrument(level = "trace", skip(self, body))]
374 fn get_discr(&self, body: &mut Body<'tcx>) -> (Statement<'tcx>, Place<'tcx>) {
375 let temp_decl = LocalDecl::new(self.discr_ty, body.span);
376 let local_decls_len = body.local_decls.push(temp_decl);
377 let temp = Place::from(local_decls_len);
378
379 let self_place = Place::from(SELF_ARG);
380 let assign = Statement::new(
381 SourceInfo::outermost(body.span),
382 StatementKind::Assign(Box::new((temp, Rvalue::Discriminant(self_place)))),
383 );
384 (assign, temp)
385 }
386
387 #[tracing::instrument(level = "trace", skip(self, body))]
389 fn replace_local(&mut self, old_local: Local, new_local: Local, body: &mut Body<'tcx>) {
390 body.local_decls.swap(old_local, new_local);
391
392 let mut visitor = RenameLocalVisitor { from: old_local, to: new_local, tcx: self.tcx };
393 visitor.visit_body(body);
394 for suspension in &mut self.suspension_points {
395 let ctxt = PlaceContext::MutatingUse(MutatingUseContext::Yield);
396 let location = Location { block: START_BLOCK, statement_index: 0 };
397 visitor.visit_place(&mut suspension.resume_arg, ctxt, location);
398 }
399 }
400}
401
impl<'tcx> MutVisitor<'tcx> for TransformVisitor<'tcx> {
    fn tcx(&self) -> TyCtxt<'tcx> {
        self.tcx
    }

    /// Asserts that a remapped local is never visited as a bare local. Places rooted at a
    /// remapped local are rewritten in `visit_place`, which does not forward here.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    fn visit_local(&mut self, local: &mut Local, _: PlaceContext, _location: Location) {
        assert!(!self.remap.contains(*local));
    }

    /// Rewrites a place rooted at a remapped local into a field access on the coroutine state.
    #[tracing::instrument(level = "trace", skip(self), ret)]
    fn visit_place(&mut self, place: &mut Place<'tcx>, _: PlaceContext, _location: Location) {
        if let Some(&Some((ty, variant_index, idx))) = self.remap.get(place.local) {
            replace_base(place, self.make_field(variant_index, idx, ty), self.tcx);
        }
    }

    /// Turns `StorageLive`/`StorageDead` of a remapped local into a nop, then visits the
    /// statement as usual.
    #[tracing::instrument(level = "trace", skip(self, stmt), ret)]
    fn visit_statement(&mut self, stmt: &mut Statement<'tcx>, location: Location) {
        if let StatementKind::StorageLive(l) | StatementKind::StorageDead(l) = stmt.kind
            && self.remap.contains(l)
        {
            stmt.make_nop(true);
        }
        self.super_statement(stmt, location);
    }

    /// Visits every terminator except `Return`, which is skipped entirely.
    #[tracing::instrument(level = "trace", skip(self, term), ret)]
    fn visit_terminator(&mut self, term: &mut Terminator<'tcx>, location: Location) {
        if let TerminatorKind::Return = term.kind {
            return;
        }
        self.super_terminator(term, location);
    }

    /// Rewrites blocks ending in `Return` or `Yield` so that they store the wrapped state
    /// value, set the coroutine discriminant and then `Return`; afterwards the block (including
    /// the statements appended here) is visited through `super_basic_block_data`.
    #[tracing::instrument(level = "trace", skip(self, data), ret)]
    fn visit_basic_block_data(&mut self, block: BasicBlock, data: &mut BasicBlockData<'tcx>) {
        match data.terminator().kind {
            TerminatorKind::Return => {
                let source_info = data.terminator().source_info;
                // The state value is written before the discriminant is updated.
                self.make_state(
                    Operand::Move(Place::return_place()),
                    source_info,
                    true,
                    &mut data.statements,
                );
                let state = VariantIdx::new(CoroutineArgs::RETURNED);
                data.statements.push(self.set_discr(state, source_info));
                data.terminator_mut().kind = TerminatorKind::Return;
            }
            TerminatorKind::Yield { ref value, resume, mut resume_arg, drop } => {
                let source_info = data.terminator().source_info;
                self.make_state(value.clone(), source_info, false, &mut data.statements);
                // Suspension states are numbered after the reserved variants, in visiting order.
                let state = CoroutineArgs::RESERVED_VARIANTS + self.suspension_points.len();

                // The place receiving the resume argument may itself be based on a remapped
                // local; it is stored outside the body, so it is rebased by hand here.
                if let Some(&Some((ty, variant, idx))) = self.remap.get(resume_arg.local) {
                    replace_base(&mut resume_arg, self.make_field(variant, idx, ty), self.tcx);
                }

                // Panics if no storage-liveness set was recorded for this yielding block.
                let storage_liveness: GrowableBitSet<Local> =
                    self.storage_liveness[block].clone().unwrap().into();

                // Emit `StorageDead` for every local that is storage-live here but is neither
                // remapped into the coroutine state nor always live.
                for i in 0..self.always_live_locals.domain_size() {
                    let l = Local::new(i);
                    let needs_storage_dead = storage_liveness.contains(l)
                        && !self.remap.contains(l)
                        && !self.always_live_locals.contains(l);
                    if needs_storage_dead {
                        data.statements
                            .push(Statement::new(source_info, StatementKind::StorageDead(l)));
                    }
                }

                self.suspension_points.push(SuspensionPoint {
                    state,
                    resume,
                    resume_arg,
                    drop,
                    storage_liveness,
                });

                let state = VariantIdx::new(state);
                data.statements.push(self.set_discr(state, source_info));
                data.terminator_mut().kind = TerminatorKind::Return;
            }
            _ => {}
        }

        self.super_basic_block_data(block, data);
    }
}
503
504fn make_aggregate_adt<'tcx>(
505 def_id: DefId,
506 variant_idx: VariantIdx,
507 args: GenericArgsRef<'tcx>,
508 operands: IndexVec<FieldIdx, Operand<'tcx>>,
509) -> Rvalue<'tcx> {
510 Rvalue::Aggregate(Box::new(AggregateKind::Adt(def_id, variant_idx, args, None, None)), operands)
511}
512
513#[tracing::instrument(level = "trace", skip(tcx, body))]
514fn make_coroutine_state_argument_indirect<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
515 let coroutine_ty = body.local_decls[SELF_ARG].ty;
516
517 let ref_coroutine_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty);
518
519 body.local_decls[SELF_ARG].ty = ref_coroutine_ty;
521
522 SelfArgVisitor::new(tcx, tcx.mk_place_deref(SELF_ARG.into())).visit_body(body);
524}
525
/// Changes the type of `SELF_ARG` from `T` to `Pin<&mut T>` (erased region).
///
/// A new local of type `&mut T` is introduced, every place rooted at `SELF_ARG` is rewritten to
/// dereference that local, and an assignment copying field 0 of `SELF_ARG` into the new local is
/// inserted into the start block.
#[tracing::instrument(level = "trace", skip(tcx, body))]
fn make_coroutine_state_argument_pinned<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
    let coroutine_ty = body.local_decls[SELF_ARG].ty;

    let ref_coroutine_ty = Ty::new_mut_ref(tcx, tcx.lifetimes.re_erased, coroutine_ty);

    let pin_did = tcx.require_lang_item(LangItem::Pin, body.span);
    let pin_adt_ref = tcx.adt_def(pin_did);
    let args = tcx.mk_args(&[ref_coroutine_ty.into()]);
    let pin_ref_coroutine_ty = Ty::new_adt(tcx, pin_adt_ref, args);

    body.local_decls[SELF_ARG].ty = pin_ref_coroutine_ty;

    let unpinned_local = body.local_decls.push(LocalDecl::new(ref_coroutine_ty, body.span));

    // Must run before the assignment below is inserted: that assignment reads `SELF_ARG`
    // itself and must not be rebased onto `unpinned_local`.
    SelfArgVisitor::new(tcx, tcx.mk_place_deref(unpinned_local.into())).visit_body(body);

    let source_info = SourceInfo::outermost(body.span);
    let pin_field = tcx.mk_place_field(SELF_ARG.into(), FieldIdx::ZERO, ref_coroutine_ty);

    let statements = &mut body.basic_blocks.as_mut_preserves_cfg()[START_BLOCK].statements;
    // Insert after any leading `Retag` statements, or at the end if the block holds only retags
    // (or is empty).
    let insert_point = statements
        .iter()
        .position(|stmt| !matches!(stmt.kind, StatementKind::Retag(..)))
        .unwrap_or(statements.len());
    statements.insert(
        insert_point,
        Statement::new(
            source_info,
            StatementKind::Assign(Box::new((
                unpinned_local.into(),
                Rvalue::Use(Operand::Copy(pin_field)),
            ))),
        ),
    );
}
566
/// Retypes every local that carries the resume argument to the task-context type and removes
/// calls to the `GetContext` lang item.
///
/// The affected locals are `CTX_ARG`, the argument local of each `GetContext` call, and the
/// `resume_arg` local of each `Yield`. Cleanup blocks are not inspected. Returns the
/// task-context type that was installed.
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
fn transform_async_context<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> Ty<'tcx> {
    let context_mut_ref = Ty::new_task_context(tcx);

    replace_resume_ty_local(tcx, body, CTX_ARG, context_mut_ref);

    let get_context_def_id = tcx.require_lang_item(LangItem::GetContext, body.span);

    for bb in body.basic_blocks.indices() {
        let bb_data = &body[bb];
        if bb_data.is_cleanup {
            continue;
        }

        match &bb_data.terminator().kind {
            TerminatorKind::Call { func, .. } => {
                let func_ty = func.ty(body, tcx);
                if let ty::FnDef(def_id, _) = *func_ty.kind()
                    && def_id == get_context_def_id
                {
                    // The call becomes a plain assignment; its argument local is retyped.
                    let local = eliminate_get_context_call(&mut body[bb]);
                    replace_resume_ty_local(tcx, body, local, context_mut_ref);
                }
            }
            TerminatorKind::Yield { resume_arg, .. } => {
                replace_resume_ty_local(tcx, body, resume_arg.local, context_mut_ref);
            }
            _ => {}
        }
    }
    context_mut_ref
}
621
622fn eliminate_get_context_call<'tcx>(bb_data: &mut BasicBlockData<'tcx>) -> Local {
623 let terminator = bb_data.terminator.take().unwrap();
624 let TerminatorKind::Call { args, destination, target, .. } = terminator.kind else {
625 bug!();
626 };
627 let [arg] = *Box::try_from(args).unwrap();
628 let local = arg.node.place().unwrap().local;
629
630 let arg = Rvalue::Use(arg.node);
631 let assign =
632 Statement::new(terminator.source_info, StatementKind::Assign(Box::new((destination, arg))));
633 bb_data.statements.push(assign);
634 bb_data.terminator = Some(Terminator {
635 source_info: terminator.source_info,
636 kind: TerminatorKind::Goto { target: target.unwrap() },
637 });
638 local
639}
640
/// Sets the declared type of `local` to `context_mut_ref`.
///
/// With debug assertions enabled, additionally checks that the previous type was the ADT of
/// the `ResumeTy` lang item and panics otherwise. Without debug assertions `tcx` and the old
/// type are unused, hence the `allow(unused)`.
#[cfg_attr(not(debug_assertions), allow(unused))]
#[tracing::instrument(level = "trace", skip(tcx, body), ret)]
fn replace_resume_ty_local<'tcx>(
    tcx: TyCtxt<'tcx>,
    body: &mut Body<'tcx>,
    local: Local,
    context_mut_ref: Ty<'tcx>,
) {
    let local_ty = std::mem::replace(&mut body.local_decls[local].ty, context_mut_ref);
    #[cfg(debug_assertions)]
    {
        if let ty::Adt(resume_ty_adt, _) = local_ty.kind() {
            let expected_adt = tcx.adt_def(tcx.require_lang_item(LangItem::ResumeTy, body.span));
            assert_eq!(*resume_ty_adt, expected_adt);
        } else {
            panic!("expected `ResumeTy`, found `{:?}`", local_ty);
        };
    }
}
662
/// Sets the body's argument count to 1.
///
/// NOTE(review): presumably this demotes the resume argument to an ordinary local so that only
/// the coroutine state remains a parameter — confirm against the caller.
fn transform_gen_context<'tcx>(body: &mut Body<'tcx>) {
    body.arg_count = 1;
}
678
/// Result of `locals_live_across_suspend_points`.
struct LivenessInfo {
    /// Union of the locals found live at any suspension point.
    saved_locals: CoroutineSavedLocals,

    /// One set per `Yield`-terminated block, in block order, renumbered to
    /// `CoroutineSavedLocal` indices.
    live_locals_at_suspension_points: Vec<DenseBitSet<CoroutineSavedLocal>>,

    /// Source info of each `Yield` terminator, parallel to `live_locals_at_suspension_points`.
    source_info_at_suspension_points: Vec<SourceInfo>,

    /// Matrix produced by `compute_storage_conflicts` over the saved locals.
    storage_conflicts: BitMatrix<CoroutineSavedLocal, CoroutineSavedLocal>,

    /// For every block ending in `Yield`, the storage-live set just before that terminator;
    /// `None` for all other blocks.
    storage_liveness: IndexVec<BasicBlock, Option<DenseBitSet<Local>>>,
}
698
/// Computes, for every `Yield` in `body`, which locals have to be kept across it.
///
/// A local is kept at a yield when it is live after the yield (or, if `movable` is false, has
/// been borrowed before it) and its storage is required just before the yield. `SELF_ARG` is
/// always excluded. Also records the storage-live set before each yield and the storage
/// conflicts between the kept locals.
#[tracing::instrument(level = "trace", skip(tcx, body))]
fn locals_live_across_suspend_points<'tcx>(
    tcx: TyCtxt<'tcx>,
    body: &Body<'tcx>,
    always_live_locals: &DenseBitSet<Local>,
    movable: bool,
) -> LivenessInfo {
    // Which locals have live storage at each point.
    let mut storage_live = MaybeStorageLive::new(std::borrow::Cow::Borrowed(always_live_locals))
        .iterate_to_fixpoint(tcx, body, None)
        .into_results_cursor(body);

    // Borrowed-locals analysis; two cursors are needed because one is handed to
    // `MaybeRequiresStorage` and the other is consulted in the loop below.
    let borrowed_locals = MaybeBorrowedLocals.iterate_to_fixpoint(tcx, body, Some("coroutine"));
    let borrowed_locals_cursor1 = ResultsCursor::new_borrowing(body, &borrowed_locals);
    let mut borrowed_locals_cursor2 = ResultsCursor::new_borrowing(body, &borrowed_locals);

    // Which locals need their storage kept around.
    let requires_storage =
        MaybeRequiresStorage::new(borrowed_locals_cursor1).iterate_to_fixpoint(tcx, body, None);
    let mut requires_storage_cursor = ResultsCursor::new_borrowing(body, &requires_storage);

    // Plain liveness of locals.
    let mut liveness =
        MaybeLiveLocals.iterate_to_fixpoint(tcx, body, Some("coroutine")).into_results_cursor(body);

    let mut storage_liveness_map = IndexVec::from_elem(None, &body.basic_blocks);
    let mut live_locals_at_suspension_points = Vec::new();
    let mut source_info_at_suspension_points = Vec::new();
    let mut live_locals_at_any_suspension_point = DenseBitSet::new_empty(body.local_decls.len());

    for (block, data) in body.basic_blocks.iter_enumerated() {
        // Only blocks ending in `Yield` are suspension points.
        let TerminatorKind::Yield { .. } = data.terminator().kind else { continue };

        // Location of the terminator itself.
        let loc = Location { block, statement_index: data.statements.len() };

        liveness.seek_to_block_end(block);
        let mut live_locals = liveness.get().clone();

        if !movable {
            // For immovable coroutines, anything borrowed before the yield is kept as well.
            borrowed_locals_cursor2.seek_before_primary_effect(loc);
            live_locals.union(borrowed_locals_cursor2.get());
        }

        // Remember the storage-live set in effect right before the yield.
        storage_live.seek_before_primary_effect(loc);
        storage_liveness_map[block] = Some(storage_live.get().clone());

        // Keep only the locals whose storage is actually required at this point.
        requires_storage_cursor.seek_before_primary_effect(loc);
        live_locals.intersect(requires_storage_cursor.get());

        // The coroutine argument itself is never saved.
        live_locals.remove(SELF_ARG);

        debug!(?loc, ?live_locals);

        live_locals_at_any_suspension_point.union(&live_locals);

        live_locals_at_suspension_points.push(live_locals);
        source_info_at_suspension_points.push(data.terminator().source_info);
    }

    debug!(?live_locals_at_any_suspension_point);
    let saved_locals = CoroutineSavedLocals(live_locals_at_any_suspension_point);

    // Re-express the per-yield sets in terms of `CoroutineSavedLocal` indices.
    let live_locals_at_suspension_points = live_locals_at_suspension_points
        .iter()
        .map(|live_here| saved_locals.renumber_bitset(live_here))
        .collect();

    let storage_conflicts = compute_storage_conflicts(
        body,
        &saved_locals,
        always_live_locals.clone(),
        &requires_storage,
    );

    LivenessInfo {
        saved_locals,
        live_locals_at_suspension_points,
        source_info_at_suspension_points,
        storage_conflicts,
        storage_liveness: storage_liveness_map,
    }
}
811
/// Set of locals that are saved in the coroutine state. A `CoroutineSavedLocal` index is the
/// position of a local within this set's iteration order (see `iter_enumerated`).
struct CoroutineSavedLocals(DenseBitSet<Local>);
818
819impl CoroutineSavedLocals {
820 fn iter_enumerated(&self) -> impl '_ + Iterator<Item = (CoroutineSavedLocal, Local)> {
823 self.iter().enumerate().map(|(i, l)| (CoroutineSavedLocal::from(i), l))
824 }
825
826 fn renumber_bitset(&self, input: &DenseBitSet<Local>) -> DenseBitSet<CoroutineSavedLocal> {
829 assert!(self.superset(input), "{:?} not a superset of {:?}", self.0, input);
830 let mut out = DenseBitSet::new_empty(self.count());
831 for (saved_local, local) in self.iter_enumerated() {
832 if input.contains(local) {
833 out.insert(saved_local);
834 }
835 }
836 out
837 }
838
839 fn get(&self, local: Local) -> Option<CoroutineSavedLocal> {
840 if !self.contains(local) {
841 return None;
842 }
843
844 let idx = self.iter().take_while(|&l| l < local).count();
845 Some(CoroutineSavedLocal::new(idx))
846 }
847}
848
849impl ops::Deref for CoroutineSavedLocals {
850 type Target = DenseBitSet<Local>;
851
852 fn deref(&self) -> &Self::Target {
853 &self.0
854 }
855}
856
/// Computes which saved locals conflict with each other, as a matrix indexed by
/// `CoroutineSavedLocal`.
///
/// Saved locals that are also in `always_live_locals` conflict with every saved local. For the
/// others, conflicts are collected by `StorageConflictVisitor` while walking `results` over the
/// reachable part of `body`.
///
/// Panics if the domain of `saved_locals` differs from the number of locals in `body`.
fn compute_storage_conflicts<'mir, 'tcx>(
    body: &'mir Body<'tcx>,
    saved_locals: &'mir CoroutineSavedLocals,
    always_live_locals: DenseBitSet<Local>,
    results: &Results<'tcx, MaybeRequiresStorage<'mir, 'tcx>>,
) -> BitMatrix<CoroutineSavedLocal, CoroutineSavedLocal> {
    assert_eq!(body.local_decls.len(), saved_locals.domain_size());

    debug!("compute_storage_conflicts({:?})", body.span);
    debug!("always_live = {:?}", always_live_locals);

    // Saved locals that are always live.
    let mut ineligible_locals = always_live_locals;
    ineligible_locals.intersect(&**saved_locals);

    let mut visitor = StorageConflictVisitor {
        body,
        saved_locals,
        // Every row of the `Local` x `Local` matrix starts out as the ineligible set.
        local_conflicts: BitMatrix::from_row_n(&ineligible_locals, body.local_decls.len()),
        eligible_storage_live: DenseBitSet::new_empty(body.local_decls.len()),
    };

    visit_reachable_results(body, results, &mut visitor);

    let local_conflicts = visitor.local_conflicts;

    // Compress the `Local`-indexed matrix down to `CoroutineSavedLocal` indices.
    let mut storage_conflicts = BitMatrix::new(saved_locals.count(), saved_locals.count());
    for (saved_local_a, local_a) in saved_locals.iter_enumerated() {
        if ineligible_locals.contains(local_a) {
            // Conflicts with everything.
            storage_conflicts.insert_all_into_row(saved_local_a);
        } else {
            // Carry over only the conflicts with other saved locals.
            for (saved_local_b, local_b) in saved_locals.iter_enumerated() {
                if local_conflicts.contains(local_a, local_b) {
                    storage_conflicts.insert(saved_local_a, saved_local_b);
                }
            }
        }
    }
    storage_conflicts
}
912
/// Dataflow-results visitor that accumulates storage conflicts between saved locals.
struct StorageConflictVisitor<'a, 'tcx> {
    /// Body being analysed; consulted to skip blocks that end in `Unreachable`.
    body: &'a Body<'tcx>,
    /// Locals saved in the coroutine state; only these participate in conflicts.
    saved_locals: &'a CoroutineSavedLocals,
    /// `Local` x `Local` conflict matrix being built.
    local_conflicts: BitMatrix<Local, Local>,
    /// Scratch set reused by `apply_state` for every visited location.
    eligible_storage_live: DenseBitSet<Local>,
}
922
/// Both hooks forward the dataflow state observed after the early effect of a statement or
/// terminator to `apply_state`.
impl<'a, 'tcx> ResultsVisitor<'tcx, MaybeRequiresStorage<'a, 'tcx>>
    for StorageConflictVisitor<'a, 'tcx>
{
    fn visit_after_early_statement_effect(
        &mut self,
        _analysis: &MaybeRequiresStorage<'a, 'tcx>,
        state: &DenseBitSet<Local>,
        _statement: &Statement<'tcx>,
        loc: Location,
    ) {
        self.apply_state(state, loc);
    }

    fn visit_after_early_terminator_effect(
        &mut self,
        _analysis: &MaybeRequiresStorage<'a, 'tcx>,
        state: &DenseBitSet<Local>,
        _terminator: &Terminator<'tcx>,
        loc: Location,
    ) {
        self.apply_state(state, loc);
    }
}
946
947impl StorageConflictVisitor<'_, '_> {
948 fn apply_state(&mut self, state: &DenseBitSet<Local>, loc: Location) {
949 if let TerminatorKind::Unreachable = self.body.basic_blocks[loc.block].terminator().kind {
951 return;
952 }
953
954 self.eligible_storage_live.clone_from(state);
955 self.eligible_storage_live.intersect(&**self.saved_locals);
956
957 for local in self.eligible_storage_live.iter() {
958 self.local_conflicts.union_row_with(&self.eligible_storage_live, local);
959 }
960
961 if self.eligible_storage_live.count() > 1 {
962 trace!("at {:?}, eligible_storage_live={:?}", loc, self.eligible_storage_live);
963 }
964 }
965}
966
/// Turns the liveness information into the coroutine layout.
///
/// Returns, in order:
/// - the remap table: for each saved local its type and the (variant, field) through which it
///   is accessed; `None` for locals that are not saved;
/// - the `CoroutineLayout` (field types and names, per-variant fields and source info, storage
///   conflicts);
/// - the per-block storage-liveness sets, passed through unchanged from `liveness`.
#[tracing::instrument(level = "trace", skip(liveness, body))]
fn compute_layout<'tcx>(
    liveness: LivenessInfo,
    body: &Body<'tcx>,
) -> (
    IndexVec<Local, Option<(Ty<'tcx>, VariantIdx, FieldIdx)>>,
    CoroutineLayout<'tcx>,
    IndexVec<BasicBlock, Option<DenseBitSet<Local>>>,
) {
    let LivenessInfo {
        saved_locals,
        live_locals_at_suspension_points,
        source_info_at_suspension_points,
        storage_conflicts,
        storage_liveness,
    } = liveness;

    // `locals` maps a saved-local index back to its `Local`; `tys` holds the matching field
    // description. Both are filled in the same order, so their indices agree.
    let mut locals = IndexVec::<CoroutineSavedLocal, _>::with_capacity(saved_locals.domain_size());
    let mut tys = IndexVec::<CoroutineSavedLocal, _>::with_capacity(saved_locals.domain_size());
    for (saved_local, local) in saved_locals.iter_enumerated() {
        debug!("coroutine saved local {:?} => {:?}", saved_local, local);

        locals.push(local);
        let decl = &body.local_decls[local];
        debug!(?decl);

        // References to non-thread-local statics and fake borrows are flagged
        // `ignore_for_traits`; everything else is not.
        let ignore_for_traits = match decl.local_info {
            ClearCrossCrate::Set(box LocalInfo::StaticRef { is_thread_local, .. }) => {
                !is_thread_local
            }
            ClearCrossCrate::Set(box LocalInfo::FakeBorrow) => true,
            _ => false,
        };
        let decl =
            CoroutineSavedTy { ty: decl.ty, source_info: decl.source_info, ignore_for_traits };
        debug!(?decl);

        tys.push(decl);
    }

    // Source info for the reserved variants: start of the body for the first, end of the body
    // for the other two.
    // NOTE(review): exactly three entries are pushed here, which assumes
    // `CoroutineArgs::RESERVED_VARIANTS == 3` — confirm.
    let body_span = body.source_scopes[OUTERMOST_SOURCE_SCOPE].span;
    let mut variant_source_info: IndexVec<VariantIdx, SourceInfo> = IndexVec::with_capacity(
        CoroutineArgs::RESERVED_VARIANTS + live_locals_at_suspension_points.len(),
    );
    variant_source_info.extend([
        SourceInfo::outermost(body_span.shrink_to_lo()),
        SourceInfo::outermost(body_span.shrink_to_hi()),
        SourceInfo::outermost(body_span.shrink_to_hi()),
    ]);

    // One (initially empty) field list per variant; the reserved variants stay empty.
    let mut variant_fields: IndexVec<VariantIdx, _> = IndexVec::from_elem_n(
        IndexVec::new(),
        CoroutineArgs::RESERVED_VARIANTS + live_locals_at_suspension_points.len(),
    );
    let mut remap = IndexVec::from_elem_n(None, saved_locals.domain_size());
    for (live_locals, &source_info_at_suspension_point, (variant_index, fields)) in izip!(
        &live_locals_at_suspension_points,
        &source_info_at_suspension_points,
        variant_fields.iter_enumerated_mut().skip(CoroutineArgs::RESERVED_VARIANTS)
    ) {
        *fields = live_locals.iter().collect();
        for (idx, &saved_local) in fields.iter_enumerated() {
            // Assigned unconditionally: if a local is live at several suspension points, the
            // entry written for the last such variant is the one that remains.
            remap[locals[saved_local]] = Some((tys[saved_local].ty, variant_index, idx));
        }
        variant_source_info.push(source_info_at_suspension_point);
    }
    debug!(?variant_fields);
    debug!(?storage_conflicts);

    // Name saved fields after debug-info variables whose value is exactly a saved local.
    let mut field_names = IndexVec::from_elem(None, &tys);
    for var in &body.var_debug_info {
        let VarDebugInfoContents::Place(place) = &var.value else { continue };
        let Some(local) = place.as_local() else { continue };
        let Some(&Some((_, variant, field))) = remap.get(local) else {
            continue;
        };

        let saved_local = variant_fields[variant][field];
        field_names.get_or_insert_with(saved_local, || var.name);
    }

    let layout = CoroutineLayout {
        field_tys: tys,
        field_names,
        variant_fields,
        variant_source_info,
        storage_conflicts,
    };
    debug!(?remap);
    debug!(?layout);
    debug!(?storage_liveness);

    (remap, layout, storage_liveness)
}
1079
1080fn insert_switch<'tcx>(
1085 body: &mut Body<'tcx>,
1086 cases: Vec<(usize, BasicBlock)>,
1087 transform: &TransformVisitor<'tcx>,
1088 default_block: BasicBlock,
1089) {
1090 let (assign, discr) = transform.get_discr(body);
1091 let switch_targets =
1092 SwitchTargets::new(cases.iter().map(|(i, bb)| ((*i) as u128, *bb)), default_block);
1093 let switch = TerminatorKind::SwitchInt { discr: Operand::Move(discr), targets: switch_targets };
1094
1095 let source_info = SourceInfo::outermost(body.span);
1096 body.basic_blocks_mut().raw.insert(
1097 0,
1098 BasicBlockData::new_stmts(
1099 vec![assign],
1100 Some(Terminator { source_info, kind: switch }),
1101 false,
1102 ),
1103 );
1104
1105 for b in body.basic_blocks_mut().iter_mut() {
1106 b.terminator_mut().successors_mut(|target| *target += 1);
1107 }
1108}
1109
1110fn insert_term_block<'tcx>(body: &mut Body<'tcx>, kind: TerminatorKind<'tcx>) -> BasicBlock {
1111 let source_info = SourceInfo::outermost(body.span);
1112 body.basic_blocks_mut().push(BasicBlockData::new(Some(Terminator { source_info, kind }), false))
1113}
1114
1115fn return_poll_ready_assign<'tcx>(tcx: TyCtxt<'tcx>, source_info: SourceInfo) -> Statement<'tcx> {
1116 let poll_def_id = tcx.require_lang_item(LangItem::Poll, source_info.span);
1118 let args = tcx.mk_args(&[tcx.types.unit.into()]);
1119 let val = Operand::Constant(Box::new(ConstOperand {
1120 span: source_info.span,
1121 user_ty: None,
1122 const_: Const::zero_sized(tcx.types.unit),
1123 }));
1124 let ready_val = Rvalue::Aggregate(
1125 Box::new(AggregateKind::Adt(poll_def_id, VariantIdx::from_usize(0), args, None, None)),
1126 indexvec![val],
1127 );
1128 Statement::new(source_info, StatementKind::Assign(Box::new((Place::return_place(), ready_val))))
1129}
1130
1131fn insert_poll_ready_block<'tcx>(tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) -> BasicBlock {
1132 let source_info = SourceInfo::outermost(body.span);
1133 body.basic_blocks_mut().push(BasicBlockData::new_stmts(
1134 [return_poll_ready_assign(tcx, source_info)].to_vec(),
1135 Some(Terminator { source_info, kind: TerminatorKind::Return }),
1136 false,
1137 ))
1138}
1139
1140fn insert_panic_block<'tcx>(
1141 tcx: TyCtxt<'tcx>,
1142 body: &mut Body<'tcx>,
1143 message: AssertMessage<'tcx>,
1144) -> BasicBlock {
1145 let assert_block = body.basic_blocks.next_index();
1146 let kind = TerminatorKind::Assert {
1147 cond: Operand::Constant(Box::new(ConstOperand {
1148 span: body.span,
1149 user_ty: None,
1150 const_: Const::from_bool(tcx, false),
1151 })),
1152 expected: true,
1153 msg: Box::new(message),
1154 target: assert_block,
1155 unwind: UnwindAction::Continue,
1156 };
1157
1158 insert_term_block(body, kind)
1159}
1160
1161fn can_return<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>, typing_env: ty::TypingEnv<'tcx>) -> bool {
1162 if body.return_ty().is_privately_uninhabited(tcx, typing_env) {
1164 return false;
1165 }
1166
1167 body.basic_blocks.iter().any(|block| matches!(block.terminator().kind, TerminatorKind::Return))
1169 }
1171
1172fn can_unwind<'tcx>(tcx: TyCtxt<'tcx>, body: &Body<'tcx>) -> bool {
1173 if !tcx.sess.panic_strategy().unwinds() {
1175 return false;
1176 }
1177
1178 for block in body.basic_blocks.iter() {
1180 match block.terminator().kind {
1181 TerminatorKind::Goto { .. }
1183 | TerminatorKind::SwitchInt { .. }
1184 | TerminatorKind::UnwindTerminate(_)
1185 | TerminatorKind::Return
1186 | TerminatorKind::Unreachable
1187 | TerminatorKind::CoroutineDrop
1188 | TerminatorKind::FalseEdge { .. }
1189 | TerminatorKind::FalseUnwind { .. } => {}
1190
1191 TerminatorKind::UnwindResume => {}
1194
1195 TerminatorKind::Yield { .. } => {
1196 unreachable!("`can_unwind` called before coroutine transform")
1197 }
1198
1199 TerminatorKind::Drop { .. }
1201 | TerminatorKind::Call { .. }
1202 | TerminatorKind::InlineAsm { .. }
1203 | TerminatorKind::Assert { .. } => return true,
1204
1205 TerminatorKind::TailCall { .. } => {
1206 unreachable!("tail calls can't be present in generators")
1207 }
1208 }
1209 }
1210
1211 false
1213}
1214
1215fn generate_poison_block_and_redirect_unwinds_there<'tcx>(
1217 transform: &TransformVisitor<'tcx>,
1218 body: &mut Body<'tcx>,
1219) {
1220 let source_info = SourceInfo::outermost(body.span);
1221 let poison_block = body.basic_blocks_mut().push(BasicBlockData::new_stmts(
1222 vec![transform.set_discr(VariantIdx::new(CoroutineArgs::POISONED), source_info)],
1223 Some(Terminator { source_info, kind: TerminatorKind::UnwindResume }),
1224 true,
1225 ));
1226
1227 for (idx, block) in body.basic_blocks_mut().iter_enumerated_mut() {
1228 let source_info = block.terminator().source_info;
1229
1230 if let TerminatorKind::UnwindResume = block.terminator().kind {
1231 if idx != poison_block {
1234 *block.terminator_mut() =
1235 Terminator { source_info, kind: TerminatorKind::Goto { target: poison_block } };
1236 }
1237 } else if !block.is_cleanup
1238 && let Some(unwind @ UnwindAction::Continue) = block.terminator_mut().unwind_mut()
1241 {
1242 *unwind = UnwindAction::Cleanup(poison_block);
1243 }
1244 }
1245}
1246
1247#[tracing::instrument(level = "trace", skip(tcx, transform, body))]
1248fn create_coroutine_resume_function<'tcx>(
1249 tcx: TyCtxt<'tcx>,
1250 transform: TransformVisitor<'tcx>,
1251 body: &mut Body<'tcx>,
1252 can_return: bool,
1253 can_unwind: bool,
1254) {
1255 if can_unwind {
1257 generate_poison_block_and_redirect_unwinds_there(&transform, body);
1258 }
1259
1260 let mut cases = create_cases(body, &transform, Operation::Resume);
1261
1262 use rustc_middle::mir::AssertKind::{ResumedAfterPanic, ResumedAfterReturn};
1263
1264 cases.insert(0, (CoroutineArgs::UNRESUMED, START_BLOCK));
1266
1267 if can_unwind {
1269 cases.insert(
1270 1,
1271 (
1272 CoroutineArgs::POISONED,
1273 insert_panic_block(tcx, body, ResumedAfterPanic(transform.coroutine_kind)),
1274 ),
1275 );
1276 }
1277
1278 if can_return {
1279 let block = match transform.coroutine_kind {
1280 CoroutineKind::Desugared(CoroutineDesugaring::Async, _)
1281 | CoroutineKind::Coroutine(_) => {
1282 if tcx.is_async_drop_in_place_coroutine(body.source.def_id()) {
1285 insert_poll_ready_block(tcx, body)
1286 } else {
1287 insert_panic_block(tcx, body, ResumedAfterReturn(transform.coroutine_kind))
1288 }
1289 }
1290 CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _)
1291 | CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
1292 transform.insert_none_ret_block(body)
1293 }
1294 };
1295 cases.insert(1, (CoroutineArgs::RETURNED, block));
1296 }
1297
1298 let default_block = insert_term_block(body, TerminatorKind::Unreachable);
1299 insert_switch(body, cases, &transform, default_block);
1300
1301 match transform.coroutine_kind {
1302 CoroutineKind::Coroutine(_)
1303 | CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _) =>
1304 {
1305 make_coroutine_state_argument_pinned(tcx, body);
1306 }
1307 CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
1310 make_coroutine_state_argument_indirect(tcx, body);
1311 }
1312 }
1313
1314 simplify::remove_dead_blocks(body);
1317
1318 pm::run_passes_no_validate(tcx, body, &[&abort_unwinding_calls::AbortUnwindingCalls], None);
1319
1320 if let Some(dumper) = MirDumper::new(tcx, "coroutine_resume", body) {
1321 dumper.dump_mir(body);
1322 }
1323}
1324
/// An operation that can be performed on a transformed coroutine.
///
/// It selects which per-suspension-point target block is dispatched to when the state switch
/// is built (see `Operation::target_block` and `create_cases`).
#[derive(PartialEq, Copy, Clone, Debug)]
enum Operation {
    /// Resuming the coroutine; maps to a suspension point's `resume` block.
    Resume,
    /// Dropping the coroutine; maps to a suspension point's optional `drop` block.
    Drop,
}
1331
1332impl Operation {
1333 fn target_block(self, point: &SuspensionPoint<'_>) -> Option<BasicBlock> {
1334 match self {
1335 Operation::Resume => Some(point.resume),
1336 Operation::Drop => point.drop,
1337 }
1338 }
1339}
1340
1341#[tracing::instrument(level = "trace", skip(transform, body))]
1342fn create_cases<'tcx>(
1343 body: &mut Body<'tcx>,
1344 transform: &TransformVisitor<'tcx>,
1345 operation: Operation,
1346) -> Vec<(usize, BasicBlock)> {
1347 let source_info = SourceInfo::outermost(body.span);
1348
1349 transform
1350 .suspension_points
1351 .iter()
1352 .filter_map(|point| {
1353 operation.target_block(point).map(|target| {
1355 let mut statements = Vec::new();
1356
1357 for l in body.local_decls.indices() {
1359 let needs_storage_live = point.storage_liveness.contains(l)
1360 && !transform.remap.contains(l)
1361 && !transform.always_live_locals.contains(l);
1362 if needs_storage_live {
1363 statements.push(Statement::new(source_info, StatementKind::StorageLive(l)));
1364 }
1365 }
1366
1367 if operation == Operation::Resume && point.resume_arg != CTX_ARG.into() {
1368 statements.push(Statement::new(
1370 source_info,
1371 StatementKind::Assign(Box::new((
1372 point.resume_arg,
1373 Rvalue::Use(Operand::Move(CTX_ARG.into())),
1374 ))),
1375 ));
1376 }
1377
1378 let block = body.basic_blocks_mut().push(BasicBlockData::new_stmts(
1380 statements,
1381 Some(Terminator { source_info, kind: TerminatorKind::Goto { target } }),
1382 false,
1383 ));
1384
1385 (point.state, block)
1386 })
1387 })
1388 .collect()
1389}
1390
1391#[instrument(level = "debug", skip(tcx), ret)]
1392pub(crate) fn mir_coroutine_witnesses<'tcx>(
1393 tcx: TyCtxt<'tcx>,
1394 def_id: LocalDefId,
1395) -> Option<CoroutineLayout<'tcx>> {
1396 let (body, _) = tcx.mir_promoted(def_id);
1397 let body = body.borrow();
1398 let body = &*body;
1399
1400 let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
1402
1403 let movable = match *coroutine_ty.kind() {
1404 ty::Coroutine(def_id, _) => tcx.coroutine_movability(def_id) == hir::Movability::Movable,
1405 ty::Error(_) => return None,
1406 _ => span_bug!(body.span, "unexpected coroutine type {}", coroutine_ty),
1407 };
1408
1409 let always_live_locals = always_storage_live_locals(body);
1412 let liveness_info = locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);
1413
1414 let (_, coroutine_layout, _) = compute_layout(liveness_info, body);
1418
1419 check_suspend_tys(tcx, &coroutine_layout, body);
1420 check_field_tys_sized(tcx, &coroutine_layout, def_id);
1421
1422 Some(coroutine_layout)
1423}
1424
1425fn check_field_tys_sized<'tcx>(
1426 tcx: TyCtxt<'tcx>,
1427 coroutine_layout: &CoroutineLayout<'tcx>,
1428 def_id: LocalDefId,
1429) {
1430 if !tcx.features().unsized_fn_params() {
1433 return;
1434 }
1435
1436 let infcx = tcx.infer_ctxt().ignoring_regions().build(TypingMode::non_body_analysis());
1441 let param_env = tcx.param_env(def_id);
1442
1443 let ocx = ObligationCtxt::new_with_diagnostics(&infcx);
1444 for field_ty in &coroutine_layout.field_tys {
1445 ocx.register_bound(
1446 ObligationCause::new(
1447 field_ty.source_info.span,
1448 def_id,
1449 ObligationCauseCode::SizedCoroutineInterior(def_id),
1450 ),
1451 param_env,
1452 field_ty.ty,
1453 tcx.require_lang_item(hir::LangItem::Sized, field_ty.source_info.span),
1454 );
1455 }
1456
1457 let errors = ocx.evaluate_obligations_error_on_ambiguity();
1458 debug!(?errors);
1459 if !errors.is_empty() {
1460 infcx.err_ctxt().report_fulfillment_errors(errors);
1461 }
1462}
1463
1464impl<'tcx> crate::MirPass<'tcx> for StateTransform {
1465 #[instrument(level = "debug", skip(self, tcx, body), ret)]
1466 fn run_pass(&self, tcx: TyCtxt<'tcx>, body: &mut Body<'tcx>) {
1467 debug!(def_id = ?body.source.def_id());
1468
1469 let Some(old_yield_ty) = body.yield_ty() else {
1470 return;
1472 };
1473 tracing::trace!(def_id = ?body.source.def_id());
1474
1475 let old_ret_ty = body.return_ty();
1476
1477 assert!(body.coroutine_drop().is_none() && body.coroutine_drop_async().is_none());
1478
1479 if let Some(dumper) = MirDumper::new(tcx, "coroutine_before", body) {
1480 dumper.dump_mir(body);
1481 }
1482
1483 let coroutine_ty = body.local_decls.raw[1].ty;
1485 let coroutine_kind = body.coroutine_kind().unwrap();
1486
1487 let ty::Coroutine(_, args) = coroutine_ty.kind() else {
1489 tcx.dcx().span_bug(body.span, format!("unexpected coroutine type {coroutine_ty}"));
1490 };
1491 let discr_ty = args.as_coroutine().discr_ty(tcx);
1492
1493 let new_ret_ty = match coroutine_kind {
1494 CoroutineKind::Desugared(CoroutineDesugaring::Async, _) => {
1495 let poll_did = tcx.require_lang_item(LangItem::Poll, body.span);
1497 let poll_adt_ref = tcx.adt_def(poll_did);
1498 let poll_args = tcx.mk_args(&[old_ret_ty.into()]);
1499 Ty::new_adt(tcx, poll_adt_ref, poll_args)
1500 }
1501 CoroutineKind::Desugared(CoroutineDesugaring::Gen, _) => {
1502 let option_did = tcx.require_lang_item(LangItem::Option, body.span);
1504 let option_adt_ref = tcx.adt_def(option_did);
1505 let option_args = tcx.mk_args(&[old_yield_ty.into()]);
1506 Ty::new_adt(tcx, option_adt_ref, option_args)
1507 }
1508 CoroutineKind::Desugared(CoroutineDesugaring::AsyncGen, _) => {
1509 old_yield_ty
1511 }
1512 CoroutineKind::Coroutine(_) => {
1513 let state_did = tcx.require_lang_item(LangItem::CoroutineState, body.span);
1515 let state_adt_ref = tcx.adt_def(state_did);
1516 let state_args = tcx.mk_args(&[old_yield_ty.into(), old_ret_ty.into()]);
1517 Ty::new_adt(tcx, state_adt_ref, state_args)
1518 }
1519 };
1520
1521 let has_async_drops = matches!(
1526 coroutine_kind,
1527 CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _)
1528 ) && has_expandable_async_drops(tcx, body, coroutine_ty);
1529
1530 if matches!(
1532 coroutine_kind,
1533 CoroutineKind::Desugared(CoroutineDesugaring::Async | CoroutineDesugaring::AsyncGen, _)
1534 ) {
1535 let context_mut_ref = transform_async_context(tcx, body);
1536 expand_async_drops(tcx, body, context_mut_ref, coroutine_kind, coroutine_ty);
1537
1538 if let Some(dumper) = MirDumper::new(tcx, "coroutine_async_drop_expand", body) {
1539 dumper.dump_mir(body);
1540 }
1541 } else {
1542 cleanup_async_drops(body);
1543 }
1544
1545 let always_live_locals = always_storage_live_locals(body);
1546 let movable = coroutine_kind.movability() == hir::Movability::Movable;
1547 let liveness_info =
1548 locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);
1549
1550 if tcx.sess.opts.unstable_opts.validate_mir {
1551 let mut vis = EnsureCoroutineFieldAssignmentsNeverAlias {
1552 assigned_local: None,
1553 saved_locals: &liveness_info.saved_locals,
1554 storage_conflicts: &liveness_info.storage_conflicts,
1555 };
1556
1557 vis.visit_body(body);
1558 }
1559
1560 let (remap, layout, storage_liveness) = compute_layout(liveness_info, body);
1564
1565 let can_return = can_return(tcx, body, body.typing_env(tcx));
1566
1567 let new_ret_local = body.local_decls.push(LocalDecl::new(new_ret_ty, body.span));
1570 tracing::trace!(?new_ret_local);
1571
1572 let mut transform = TransformVisitor {
1578 tcx,
1579 coroutine_kind,
1580 remap,
1581 storage_liveness,
1582 always_live_locals,
1583 suspension_points: Vec::new(),
1584 discr_ty,
1585 new_ret_local,
1586 old_ret_ty,
1587 old_yield_ty,
1588 };
1589 transform.visit_body(body);
1590
1591 transform.replace_local(RETURN_PLACE, new_ret_local, body);
1593
1594 let source_info = SourceInfo::outermost(body.span);
1597 let args_iter = body.args_iter();
1598 body.basic_blocks.as_mut()[START_BLOCK].statements.splice(
1599 0..0,
1600 args_iter.filter_map(|local| {
1601 let (ty, variant_index, idx) = transform.remap[local]?;
1602 let lhs = transform.make_field(variant_index, idx, ty);
1603 let rhs = Rvalue::Use(Operand::Move(local.into()));
1604 let assign = StatementKind::Assign(Box::new((lhs, rhs)));
1605 Some(Statement::new(source_info, assign))
1606 }),
1607 );
1608
1609 body.arg_count = 2; body.spread_arg = None;
1612
1613 if matches!(coroutine_kind, CoroutineKind::Desugared(CoroutineDesugaring::Gen, _)) {
1615 transform_gen_context(body);
1616 }
1617
1618 for var in &mut body.var_debug_info {
1622 var.argument_index = None;
1623 }
1624
1625 body.coroutine.as_mut().unwrap().yield_ty = None;
1626 body.coroutine.as_mut().unwrap().resume_ty = None;
1627 body.coroutine.as_mut().unwrap().coroutine_layout = Some(layout);
1628
1629 let drop_clean = insert_clean_drop(tcx, body, has_async_drops);
1637
1638 if let Some(dumper) = MirDumper::new(tcx, "coroutine_pre-elab", body) {
1639 dumper.dump_mir(body);
1640 }
1641
1642 elaborate_coroutine_drops(tcx, body);
1646
1647 if let Some(dumper) = MirDumper::new(tcx, "coroutine_post-transform", body) {
1648 dumper.dump_mir(body);
1649 }
1650
1651 let can_unwind = can_unwind(tcx, body);
1652
1653 if has_async_drops {
1655 let mut drop_shim =
1657 create_coroutine_drop_shim_async(tcx, &transform, body, drop_clean, can_unwind);
1658 deref_finder(tcx, &mut drop_shim, false);
1660 body.coroutine.as_mut().unwrap().coroutine_drop_async = Some(drop_shim);
1661 } else {
1662 let mut drop_shim =
1664 create_coroutine_drop_shim(tcx, &transform, coroutine_ty, body, drop_clean);
1665 deref_finder(tcx, &mut drop_shim, false);
1667 body.coroutine.as_mut().unwrap().coroutine_drop = Some(drop_shim);
1668
1669 let mut proxy_shim = create_coroutine_drop_shim_proxy_async(tcx, body);
1671 deref_finder(tcx, &mut proxy_shim, false);
1672 body.coroutine.as_mut().unwrap().coroutine_drop_proxy_async = Some(proxy_shim);
1673 }
1674
1675 create_coroutine_resume_function(tcx, transform, body, can_return, can_unwind);
1677
1678 deref_finder(tcx, body, false);
1680 }
1681
1682 fn is_required(&self) -> bool {
1683 true
1684 }
1685}
1686
/// MIR visitor that looks for assignments between two coroutine saved locals whose storage is
/// not marked as conflicting, and raises a compiler bug when it finds one.
///
/// `StateTransform::run_pass` runs it under `-Zvalidate-mir`, before the body is rewritten by
/// `TransformVisitor`.
struct EnsureCoroutineFieldAssignmentsNeverAlias<'a> {
    /// Used to look up the saved local (if any) that corresponds to a MIR local.
    saved_locals: &'a CoroutineSavedLocals,
    /// Matrix queried with `(lhs, rhs)` to decide whether two saved locals conflict.
    storage_conflicts: &'a BitMatrix<CoroutineSavedLocal, CoroutineSavedLocal>,
    /// Saved local on the left-hand side of the assignment currently being checked; `None`
    /// outside of `check_assigned_place`.
    assigned_local: Option<CoroutineSavedLocal>,
}
1704
1705impl EnsureCoroutineFieldAssignmentsNeverAlias<'_> {
1706 fn saved_local_for_direct_place(&self, place: Place<'_>) -> Option<CoroutineSavedLocal> {
1707 if place.is_indirect() {
1708 return None;
1709 }
1710
1711 self.saved_locals.get(place.local)
1712 }
1713
1714 fn check_assigned_place(&mut self, place: Place<'_>, f: impl FnOnce(&mut Self)) {
1715 if let Some(assigned_local) = self.saved_local_for_direct_place(place) {
1716 assert!(self.assigned_local.is_none(), "`check_assigned_place` must not recurse");
1717
1718 self.assigned_local = Some(assigned_local);
1719 f(self);
1720 self.assigned_local = None;
1721 }
1722 }
1723}
1724
1725impl<'tcx> Visitor<'tcx> for EnsureCoroutineFieldAssignmentsNeverAlias<'_> {
1726 fn visit_place(&mut self, place: &Place<'tcx>, context: PlaceContext, location: Location) {
1727 let Some(lhs) = self.assigned_local else {
1728 assert!(!context.is_use());
1733 return;
1734 };
1735
1736 let Some(rhs) = self.saved_local_for_direct_place(*place) else { return };
1737
1738 if !self.storage_conflicts.contains(lhs, rhs) {
1739 bug!(
1740 "Assignment between coroutine saved locals whose storage is not \
1741 marked as conflicting: {:?}: {:?} = {:?}",
1742 location,
1743 lhs,
1744 rhs,
1745 );
1746 }
1747 }
1748
1749 fn visit_statement(&mut self, statement: &Statement<'tcx>, location: Location) {
1750 match &statement.kind {
1751 StatementKind::Assign(box (lhs, rhs)) => {
1752 self.check_assigned_place(*lhs, |this| this.visit_rvalue(rhs, location));
1753 }
1754
1755 StatementKind::FakeRead(..)
1756 | StatementKind::SetDiscriminant { .. }
1757 | StatementKind::StorageLive(_)
1758 | StatementKind::StorageDead(_)
1759 | StatementKind::Retag(..)
1760 | StatementKind::AscribeUserType(..)
1761 | StatementKind::PlaceMention(..)
1762 | StatementKind::Coverage(..)
1763 | StatementKind::Intrinsic(..)
1764 | StatementKind::ConstEvalCounter
1765 | StatementKind::BackwardIncompatibleDropHint { .. }
1766 | StatementKind::Nop => {}
1767 }
1768 }
1769
1770 fn visit_terminator(&mut self, terminator: &Terminator<'tcx>, location: Location) {
1771 match &terminator.kind {
1774 TerminatorKind::Call {
1775 func,
1776 args,
1777 destination,
1778 target: Some(_),
1779 unwind: _,
1780 call_source: _,
1781 fn_span: _,
1782 } => {
1783 self.check_assigned_place(*destination, |this| {
1784 this.visit_operand(func, location);
1785 for arg in args {
1786 this.visit_operand(&arg.node, location);
1787 }
1788 });
1789 }
1790
1791 TerminatorKind::Yield { value, resume: _, resume_arg, drop: _ } => {
1792 self.check_assigned_place(*resume_arg, |this| this.visit_operand(value, location));
1793 }
1794
1795 TerminatorKind::InlineAsm { .. } => {}
1797
1798 TerminatorKind::Call { .. }
1799 | TerminatorKind::Goto { .. }
1800 | TerminatorKind::SwitchInt { .. }
1801 | TerminatorKind::UnwindResume
1802 | TerminatorKind::UnwindTerminate(_)
1803 | TerminatorKind::Return
1804 | TerminatorKind::TailCall { .. }
1805 | TerminatorKind::Unreachable
1806 | TerminatorKind::Drop { .. }
1807 | TerminatorKind::Assert { .. }
1808 | TerminatorKind::CoroutineDrop
1809 | TerminatorKind::FalseEdge { .. }
1810 | TerminatorKind::FalseUnwind { .. } => {}
1811 }
1812 }
1813}
1814
1815fn check_suspend_tys<'tcx>(tcx: TyCtxt<'tcx>, layout: &CoroutineLayout<'tcx>, body: &Body<'tcx>) {
1816 let mut linted_tys = FxHashSet::default();
1817
1818 for (variant, yield_source_info) in
1819 layout.variant_fields.iter().zip(&layout.variant_source_info)
1820 {
1821 debug!(?variant);
1822 for &local in variant {
1823 let decl = &layout.field_tys[local];
1824 debug!(?decl);
1825
1826 if !decl.ignore_for_traits && linted_tys.insert(decl.ty) {
1827 let Some(hir_id) = decl.source_info.scope.lint_root(&body.source_scopes) else {
1828 continue;
1829 };
1830
1831 check_must_not_suspend_ty(
1832 tcx,
1833 decl.ty,
1834 hir_id,
1835 SuspendCheckData {
1836 source_span: decl.source_info.span,
1837 yield_span: yield_source_info.span,
1838 plural_len: 1,
1839 ..Default::default()
1840 },
1841 );
1842 }
1843 }
1844 }
1845}
1846
/// Context threaded through the `must_not_suspend` checks while they descend into a type.
#[derive(Default)]
struct SuspendCheckData<'a> {
    /// Span of the value being checked; used as the span of emitted diagnostics.
    source_span: Span,
    /// Span of the suspension point the value is held across.
    yield_span: Span,
    /// Text placed before the offending definition in the diagnostic (e.g. `"boxed "`).
    descr_pre: &'a str,
    /// Text placed after the offending definition in the diagnostic (e.g. `" trait object"`).
    descr_post: &'a str,
    /// Count fed to `pluralize!` to decide whether descriptions get a plural suffix.
    plural_len: usize,
}
1855
1856fn check_must_not_suspend_ty<'tcx>(
1863 tcx: TyCtxt<'tcx>,
1864 ty: Ty<'tcx>,
1865 hir_id: hir::HirId,
1866 data: SuspendCheckData<'_>,
1867) -> bool {
1868 if ty.is_unit() {
1869 return false;
1870 }
1871
1872 let plural_suffix = pluralize!(data.plural_len);
1873
1874 debug!("Checking must_not_suspend for {}", ty);
1875
1876 match *ty.kind() {
1877 ty::Adt(_, args) if ty.is_box() => {
1878 let boxed_ty = args.type_at(0);
1879 let allocator_ty = args.type_at(1);
1880 check_must_not_suspend_ty(
1881 tcx,
1882 boxed_ty,
1883 hir_id,
1884 SuspendCheckData { descr_pre: &format!("{}boxed ", data.descr_pre), ..data },
1885 ) || check_must_not_suspend_ty(
1886 tcx,
1887 allocator_ty,
1888 hir_id,
1889 SuspendCheckData { descr_pre: &format!("{}allocator ", data.descr_pre), ..data },
1890 )
1891 }
1892 ty::Adt(def, _) if def.repr().scalable() => {
1898 tcx.dcx()
1899 .span_err(data.source_span, "scalable vectors cannot be held over await points");
1900 true
1901 }
1902 ty::Adt(def, _) => check_must_not_suspend_def(tcx, def.did(), hir_id, data),
1903 ty::Alias(ty::Opaque, ty::AliasTy { def_id: def, .. }) => {
1905 let mut has_emitted = false;
1906 for &(predicate, _) in tcx.explicit_item_bounds(def).skip_binder() {
1907 if let ty::ClauseKind::Trait(ref poly_trait_predicate) =
1909 predicate.kind().skip_binder()
1910 {
1911 let def_id = poly_trait_predicate.trait_ref.def_id;
1912 let descr_pre = &format!("{}implementer{} of ", data.descr_pre, plural_suffix);
1913 if check_must_not_suspend_def(
1914 tcx,
1915 def_id,
1916 hir_id,
1917 SuspendCheckData { descr_pre, ..data },
1918 ) {
1919 has_emitted = true;
1920 break;
1921 }
1922 }
1923 }
1924 has_emitted
1925 }
1926 ty::Dynamic(binder, _) => {
1927 let mut has_emitted = false;
1928 for predicate in binder.iter() {
1929 if let ty::ExistentialPredicate::Trait(ref trait_ref) = predicate.skip_binder() {
1930 let def_id = trait_ref.def_id;
1931 let descr_post = &format!(" trait object{}{}", plural_suffix, data.descr_post);
1932 if check_must_not_suspend_def(
1933 tcx,
1934 def_id,
1935 hir_id,
1936 SuspendCheckData { descr_post, ..data },
1937 ) {
1938 has_emitted = true;
1939 break;
1940 }
1941 }
1942 }
1943 has_emitted
1944 }
1945 ty::Tuple(fields) => {
1946 let mut has_emitted = false;
1947 for (i, ty) in fields.iter().enumerate() {
1948 let descr_post = &format!(" in tuple element {i}");
1949 if check_must_not_suspend_ty(
1950 tcx,
1951 ty,
1952 hir_id,
1953 SuspendCheckData { descr_post, ..data },
1954 ) {
1955 has_emitted = true;
1956 }
1957 }
1958 has_emitted
1959 }
1960 ty::Array(ty, len) => {
1961 let descr_pre = &format!("{}array{} of ", data.descr_pre, plural_suffix);
1962 check_must_not_suspend_ty(
1963 tcx,
1964 ty,
1965 hir_id,
1966 SuspendCheckData {
1967 descr_pre,
1968 plural_len: len.try_to_target_usize(tcx).unwrap_or(0) as usize + 1,
1970 ..data
1971 },
1972 )
1973 }
1974 ty::Ref(_region, ty, _mutability) => {
1977 let descr_pre = &format!("{}reference{} to ", data.descr_pre, plural_suffix);
1978 check_must_not_suspend_ty(tcx, ty, hir_id, SuspendCheckData { descr_pre, ..data })
1979 }
1980 _ => false,
1981 }
1982}
1983
1984fn check_must_not_suspend_def(
1985 tcx: TyCtxt<'_>,
1986 def_id: DefId,
1987 hir_id: hir::HirId,
1988 data: SuspendCheckData<'_>,
1989) -> bool {
1990 if let Some(reason_str) = find_attr!(tcx, def_id, MustNotSupend {reason} => reason) {
1991 let reason =
1992 reason_str.map(|s| errors::MustNotSuspendReason { span: data.source_span, reason: s });
1993 tcx.emit_node_span_lint(
1994 rustc_session::lint::builtin::MUST_NOT_SUSPEND,
1995 hir_id,
1996 data.source_span,
1997 errors::MustNotSupend {
1998 tcx,
1999 yield_sp: data.yield_span,
2000 reason,
2001 src_sp: data.source_span,
2002 pre: data.descr_pre,
2003 def_id,
2004 post: data.descr_post,
2005 },
2006 );
2007
2008 true
2009 } else {
2010 false
2011 }
2012}