1use std::ops;
24
25use itertools::izip;
26use rustc_abi::{FieldIdx, VariantIdx};
27use rustc_data_structures::fx::FxHashSet;
28use rustc_errors::pluralize;
29use rustc_hir::{self as hir, find_attr};
30use rustc_index::bit_set::{BitMatrix, DenseBitSet};
31use rustc_index::{Idx, IndexVec};
32use rustc_middle::mir::*;
33use rustc_middle::span_bug;
34use rustc_middle::ty::{self, CoroutineArgs, CoroutineArgsExt, Ty, TyCtxt, TypingMode};
35use rustc_mir_dataflow::impls::{
36 MaybeBorrowedLocals, MaybeLiveLocals, MaybeRequiresStorage, MaybeStorageLive,
37 always_storage_live_locals,
38};
39use rustc_mir_dataflow::{
40 Analysis, Results, ResultsCursor, ResultsVisitor, visit_reachable_results,
41};
42use rustc_span::Span;
43use rustc_span::def_id::{DefId, LocalDefId};
44use rustc_trait_selection::error_reporting::InferCtxtErrorExt;
45use rustc_trait_selection::infer::TyCtxtInferExt as _;
46use rustc_trait_selection::traits::{ObligationCause, ObligationCauseCode, ObligationCtxt};
47use tracing::{debug, instrument, trace};
48
49use crate::errors::{MustNotSupend, MustNotSuspendReason};
50
/// The coroutine's first argument (`Local::arg(0)`). It is removed from the live set at
/// every suspension point, so it never becomes a saved local.
const SELF_ARG: Local = Local::arg(0);
52
/// Result of `locals_live_across_suspend_points`, consumed by `compute_layout`.
pub(super) struct LivenessInfo {
    /// Locals that are live at one or more suspension points (blocks ending in `Yield`).
    /// These are the locals that get stored in the coroutine.
    pub(super) saved_locals: CoroutineSavedLocals,

    /// For each suspension point, the saved locals live there, renumbered into
    /// `CoroutineSavedLocal` indices. Suspension points appear in the order their blocks
    /// are visited.
    live_locals_at_suspension_points: Vec<DenseBitSet<CoroutineSavedLocal>>,

    /// Source info of each suspension point's `Yield` terminator. Parallel to
    /// `live_locals_at_suspension_points`.
    source_info_at_suspension_points: Vec<SourceInfo>,

    /// Pairs of saved locals that may require storage at the same point, as computed by
    /// `compute_storage_conflicts`. Ineligible (always-live) saved locals conflict with
    /// every saved local.
    pub(super) storage_conflicts: BitMatrix<CoroutineSavedLocal, CoroutineSavedLocal>,

    /// For each block ending in `Yield`, the `MaybeStorageLive` state just before its
    /// terminator. `None` for every other block.
    storage_liveness: IndexVec<BasicBlock, Option<DenseBitSet<Local>>>,
}
72
/// Computes which locals of the coroutine `body` must be saved across each suspension point
/// (each block ending in a `Yield` terminator), and the storage conflicts between them.
///
/// At each suspension point the saved set is computed as follows:
/// 1. Start with the locals live at the end of the block (`MaybeLiveLocals`).
/// 2. For immovable coroutines (`movable == false`), add every local that may be borrowed
///    just before the `Yield` (`MaybeBorrowedLocals`).
/// 3. Intersect with the locals that may require storage just before the `Yield`
///    (`MaybeRequiresStorage`).
/// 4. Remove `SELF_ARG`.
///
/// The union of these sets over all suspension points becomes `saved_locals`. Each
/// per-point set is then renumbered into `CoroutineSavedLocal` indices.
/// `always_live_locals` seeds the storage-liveness analysis and is also passed to
/// `compute_storage_conflicts`.
#[tracing::instrument(level = "trace", skip(tcx, body))]
pub(super) fn locals_live_across_suspend_points<'tcx>(
    tcx: TyCtxt<'tcx>,
    body: &Body<'tcx>,
    always_live_locals: &DenseBitSet<Local>,
    movable: bool,
) -> LivenessInfo {
    // Which locals may have live storage, seeded with `always_live_locals`.
    let mut storage_live = MaybeStorageLive::new(std::borrow::Cow::Borrowed(always_live_locals))
        .iterate_to_fixpoint(tcx, body, None)
        .into_results_cursor(body);

    // Locals that may be borrowed. Two cursors share these results: the first is handed to
    // `MaybeRequiresStorage`, the second is used below for immovable coroutines.
    let borrowed_locals = MaybeBorrowedLocals.iterate_to_fixpoint(tcx, body, Some("coroutine"));
    let borrowed_locals_cursor1 = ResultsCursor::new_borrowing(body, &borrowed_locals);
    let mut borrowed_locals_cursor2 = ResultsCursor::new_borrowing(body, &borrowed_locals);

    // Locals that may require storage. These results are reused by `compute_storage_conflicts`.
    let requires_storage =
        MaybeRequiresStorage::new(borrowed_locals_cursor1).iterate_to_fixpoint(tcx, body, None);
    let mut requires_storage_cursor = ResultsCursor::new_borrowing(body, &requires_storage);

    // Use-based liveness of locals.
    let mut liveness =
        MaybeLiveLocals.iterate_to_fixpoint(tcx, body, Some("coroutine")).into_results_cursor(body);

    // Only entries for blocks ending in `Yield` are filled in; all others stay `None`.
    let mut storage_liveness_map = IndexVec::from_elem(None, &body.basic_blocks);
    let mut live_locals_at_suspension_points = Vec::new();
    let mut source_info_at_suspension_points = Vec::new();
    let mut live_locals_at_any_suspension_point = DenseBitSet::new_empty(body.local_decls.len());

    for (block, data) in body.basic_blocks.iter_enumerated() {
        // Only `Yield` terminators are suspension points.
        let TerminatorKind::Yield { .. } = data.terminator().kind else { continue };

        // Location of the terminator itself (it comes after the last statement).
        let loc = Location { block, statement_index: data.statements.len() };

        liveness.seek_to_block_end(block);
        let mut live_locals = liveness.get().clone();

        if !movable {
            // NOTE(review): presumably a borrow of a local in an immovable coroutine may be
            // used after resuming even when `MaybeLiveLocals` sees no later use of the local
            // itself, so borrowed locals are conservatively kept.
            borrowed_locals_cursor2.seek_before_primary_effect(loc);
            live_locals.union(borrowed_locals_cursor2.get());
        }

        // Storage liveness just before the `Yield`; returned via `LivenessInfo::storage_liveness`.
        storage_live.seek_before_primary_effect(loc);
        storage_liveness_map[block] = Some(storage_live.get().clone());

        // Keep only locals that may still require storage at this point.
        requires_storage_cursor.seek_before_primary_effect(loc);
        live_locals.intersect(requires_storage_cursor.get());

        // The coroutine's own argument is never saved.
        live_locals.remove(SELF_ARG);

        debug!(?loc, ?live_locals);

        live_locals_at_any_suspension_point.union(&live_locals);

        live_locals_at_suspension_points.push(live_locals);
        source_info_at_suspension_points.push(data.terminator().source_info);
    }

    debug!(?live_locals_at_any_suspension_point);
    let saved_locals = CoroutineSavedLocals(live_locals_at_any_suspension_point);

    // Translate each per-suspension-point set from `Local` indices into the dense
    // `CoroutineSavedLocal` index space of `saved_locals`.
    let live_locals_at_suspension_points = live_locals_at_suspension_points
        .iter()
        .map(|live_here| saved_locals.renumber_bitset(live_here))
        .collect();

    let storage_conflicts = compute_storage_conflicts(
        body,
        &saved_locals,
        always_live_locals.clone(),
        &requires_storage,
    );

    LivenessInfo {
        saved_locals,
        live_locals_at_suspension_points,
        source_info_at_suspension_points,
        storage_conflicts,
        storage_liveness: storage_liveness_map,
    }
}
185
/// The set of MIR locals saved across at least one suspension point.
///
/// Each saved local is also identified by a dense `CoroutineSavedLocal` index: its position
/// when iterating the underlying set (see `iter_enumerated` and `get`).
pub(super) struct CoroutineSavedLocals(DenseBitSet<Local>);
192
193impl CoroutineSavedLocals {
194 fn iter_enumerated(&self) -> impl '_ + Iterator<Item = (CoroutineSavedLocal, Local)> {
197 self.iter().enumerate().map(|(i, l)| (CoroutineSavedLocal::from(i), l))
198 }
199
200 fn renumber_bitset(&self, input: &DenseBitSet<Local>) -> DenseBitSet<CoroutineSavedLocal> {
203 assert!(self.superset(input), "{:?} not a superset of {:?}", self.0, input);
204 let mut out = DenseBitSet::new_empty(self.count());
205 for (saved_local, local) in self.iter_enumerated() {
206 if input.contains(local) {
207 out.insert(saved_local);
208 }
209 }
210 out
211 }
212
213 pub(super) fn get(&self, local: Local) -> Option<CoroutineSavedLocal> {
214 if !self.contains(local) {
215 return None;
216 }
217
218 let idx = self.iter().take_while(|&l| l < local).count();
219 Some(CoroutineSavedLocal::new(idx))
220 }
221}
222
223impl ops::Deref for CoroutineSavedLocals {
224 type Target = DenseBitSet<Local>;
225
226 fn deref(&self) -> &Self::Target {
227 &self.0
228 }
229}
230
231fn compute_storage_conflicts<'mir, 'tcx>(
236 body: &'mir Body<'tcx>,
237 saved_locals: &'mir CoroutineSavedLocals,
238 always_live_locals: DenseBitSet<Local>,
239 results: &Results<'tcx, MaybeRequiresStorage<'mir, 'tcx>>,
240) -> BitMatrix<CoroutineSavedLocal, CoroutineSavedLocal> {
241 assert_eq!(body.local_decls.len(), saved_locals.domain_size());
242
243 debug!("compute_storage_conflicts({:?})", body.span);
244 debug!("always_live = {:?}", always_live_locals);
245
246 let mut ineligible_locals = always_live_locals;
249 ineligible_locals.intersect(&**saved_locals);
250
251 let mut visitor = StorageConflictVisitor {
253 body,
254 saved_locals,
255 local_conflicts: BitMatrix::from_row_n(&ineligible_locals, body.local_decls.len()),
256 eligible_storage_live: DenseBitSet::new_empty(body.local_decls.len()),
257 };
258
259 visit_reachable_results(body, results, &mut visitor);
260
261 let local_conflicts = visitor.local_conflicts;
262
263 let mut storage_conflicts = BitMatrix::new(saved_locals.count(), saved_locals.count());
271 for (saved_local_a, local_a) in saved_locals.iter_enumerated() {
272 if ineligible_locals.contains(local_a) {
273 storage_conflicts.insert_all_into_row(saved_local_a);
275 } else {
276 for (saved_local_b, local_b) in saved_locals.iter_enumerated() {
278 if local_conflicts.contains(local_a, local_b) {
279 storage_conflicts.insert(saved_local_a, saved_local_b);
280 }
281 }
282 }
283 }
284 storage_conflicts
285}
286
/// Dataflow results visitor that accumulates which saved locals may require storage at the
/// same location (see `apply_state`).
struct StorageConflictVisitor<'a, 'tcx> {
    body: &'a Body<'tcx>,
    saved_locals: &'a CoroutineSavedLocals,
    /// Row `a` contains `b` when `a` and `b` were both saved and in the visited state at
    /// some location. Rows are also seeded with the ineligible locals by the caller.
    local_conflicts: BitMatrix<Local, Local>,
    /// Scratch set reused by `apply_state` at every location (refilled via `clone_from`).
    eligible_storage_live: DenseBitSet<Local>,
}
296
/// Feeds the `MaybeRequiresStorage` state at every statement and terminator into
/// `StorageConflictVisitor::apply_state`.
impl<'a, 'tcx> ResultsVisitor<'tcx, MaybeRequiresStorage<'a, 'tcx>>
    for StorageConflictVisitor<'a, 'tcx>
{
    // Records conflicts using the state observed after the statement's early effect.
    fn visit_after_early_statement_effect(
        &mut self,
        _analysis: &MaybeRequiresStorage<'a, 'tcx>,
        state: &DenseBitSet<Local>,
        _statement: &Statement<'tcx>,
        loc: Location,
    ) {
        self.apply_state(state, loc);
    }

    // Records conflicts using the state observed after the terminator's early effect.
    fn visit_after_early_terminator_effect(
        &mut self,
        _analysis: &MaybeRequiresStorage<'a, 'tcx>,
        state: &DenseBitSet<Local>,
        _terminator: &Terminator<'tcx>,
        loc: Location,
    ) {
        self.apply_state(state, loc);
    }
}
320
321impl StorageConflictVisitor<'_, '_> {
322 fn apply_state(&mut self, state: &DenseBitSet<Local>, loc: Location) {
323 if let TerminatorKind::Unreachable = self.body.basic_blocks[loc.block].terminator().kind {
325 return;
326 }
327
328 self.eligible_storage_live.clone_from(state);
329 self.eligible_storage_live.intersect(&**self.saved_locals);
330
331 for local in self.eligible_storage_live.iter() {
332 self.local_conflicts.union_row_with(&self.eligible_storage_live, local);
333 }
334
335 if self.eligible_storage_live.count() > 1 {
336 trace!("at {:?}, eligible_storage_live={:?}", loc, self.eligible_storage_live);
337 }
338 }
339}
340
/// Turns `liveness` into a `CoroutineLayout` for `body`.
///
/// Returns a tuple of:
/// * `remap`: indexed by MIR `Local`. For a saved local it holds `Some((ty, variant, field))`,
///   a variant and field index where the local is stored; every other local maps to `None`.
///   If a saved local is live at several suspension points, the entry written for the last
///   such variant is the one kept.
/// * the `CoroutineLayout`, containing:
///   - the saved-local types;
///   - the fields of each variant (the first `CoroutineArgs::RESERVED_VARIANTS` variants
///     have none, and suspension point `i` becomes variant `RESERVED_VARIANTS + i`);
///   - per-variant source info;
///   - the storage conflicts from `liveness`.
/// * `storage_liveness`, passed through unchanged from `liveness`.
#[tracing::instrument(level = "trace", skip(liveness, body))]
pub(super) fn compute_layout<'tcx>(
    liveness: LivenessInfo,
    body: &Body<'tcx>,
) -> (
    IndexVec<Local, Option<(Ty<'tcx>, VariantIdx, FieldIdx)>>,
    CoroutineLayout<'tcx>,
    IndexVec<BasicBlock, Option<DenseBitSet<Local>>>,
) {
    let LivenessInfo {
        saved_locals,
        live_locals_at_suspension_points,
        source_info_at_suspension_points,
        storage_conflicts,
        storage_liveness,
    } = liveness;

    // One entry per saved local, in `CoroutineSavedLocal` order.
    let mut tys: IndexVec<CoroutineSavedLocal, CoroutineSavedTy<'_>> = saved_locals
        .iter_enumerated()
        .map(|(saved_local, local)| {
            debug!("coroutine saved local {:?} => {:?}", saved_local, local);

            let decl = &body.local_decls[local];

            // References to non-thread-local statics and fake borrows are flagged
            // `ignore_for_traits` (among other effects, `check_suspend_tys` skips them).
            // NOTE(review): presumably these don't represent real user-held values — a
            // static's address can be re-derived after resuming and fake borrows exist only
            // for borrowck — confirm.
            let ignore_for_traits = match decl.local_info {
                ClearCrossCrate::Set(LocalInfo::StaticRef { is_thread_local, .. }) => {
                    !is_thread_local
                }
                ClearCrossCrate::Set(LocalInfo::FakeBorrow) => true,
                _ => false,
            };

            CoroutineSavedTy {
                ty: decl.ty,
                source_info: decl.source_info,
                ignore_for_traits,
                // Filled in from `body.var_debug_info` further below.
                debuginfo_name: None,
            }
        })
        .collect();

    // Source info for the reserved variants comes first. Exactly three entries are pushed,
    // which assumes `CoroutineArgs::RESERVED_VARIANTS == 3` (presumably the unresumed,
    // returned and poisoned states — confirm). The first points at the start of the body,
    // the other two at its end.
    let body_span = body.source_scopes[OUTERMOST_SOURCE_SCOPE].span;
    let mut variant_source_info: IndexVec<VariantIdx, SourceInfo> = IndexVec::with_capacity(
        CoroutineArgs::RESERVED_VARIANTS + live_locals_at_suspension_points.len(),
    );
    variant_source_info.extend([
        SourceInfo::outermost(body_span.shrink_to_lo()),
        SourceInfo::outermost(body_span.shrink_to_hi()),
        SourceInfo::outermost(body_span.shrink_to_hi()),
    ]);

    // Inverse of the `Local -> CoroutineSavedLocal` numbering.
    let reverse_local_map: IndexVec<CoroutineSavedLocal, Local> = saved_locals.iter().collect();

    // One (initially empty) field list per reserved variant and per suspension point.
    let mut variant_fields: IndexVec<VariantIdx, _> = IndexVec::from_elem_n(
        IndexVec::new(),
        CoroutineArgs::RESERVED_VARIANTS + live_locals_at_suspension_points.len(),
    );
    let mut remap = IndexVec::from_elem_n(None, saved_locals.domain_size());
    for (live_locals, &source_info_at_suspension_point, (variant_index, fields)) in izip!(
        &live_locals_at_suspension_points,
        &source_info_at_suspension_points,
        variant_fields.iter_enumerated_mut().skip(CoroutineArgs::RESERVED_VARIANTS)
    ) {
        // The variant's fields are the saved locals live at this suspension point.
        *fields = live_locals.iter().collect();
        for (idx, &saved_local) in fields.iter_enumerated() {
            // Overwrites any entry from an earlier variant, so the last variant wins.
            remap[reverse_local_map[saved_local]] = Some((tys[saved_local].ty, variant_index, idx));
        }
        variant_source_info.push(source_info_at_suspension_point);
    }
    debug!(?variant_fields);
    debug!(?storage_conflicts);

    // Name each saved local after the first debuginfo variable whose place is exactly that
    // local (`get_or_insert` keeps the first name found).
    for var in &body.var_debug_info {
        let VarDebugInfoContents::Place(place) = &var.value else { continue };
        let Some(local) = place.as_local() else { continue };
        let Some(&Some((_, variant, field))) = remap.get(local) else {
            continue;
        };

        let saved_local: CoroutineSavedLocal = variant_fields[variant][field];
        tys[saved_local].debuginfo_name.get_or_insert(var.name);
    }

    let layout =
        CoroutineLayout { field_tys: tys, variant_fields, variant_source_info, storage_conflicts };
    debug!(?remap);
    debug!(?layout);
    debug!(?storage_liveness);

    (remap, layout, storage_liveness)
}
452
453#[instrument(level = "debug", skip(tcx), ret)]
454pub(crate) fn mir_coroutine_witnesses<'tcx>(
455 tcx: TyCtxt<'tcx>,
456 def_id: LocalDefId,
457) -> Option<CoroutineLayout<'tcx>> {
458 let (body, _) = tcx.mir_promoted(def_id);
459 let body = body.borrow();
460 let body = &*body;
461
462 let coroutine_ty = body.local_decls[ty::CAPTURE_STRUCT_LOCAL].ty;
464
465 let movable = match *coroutine_ty.kind() {
466 ty::Coroutine(def_id, _) => tcx.coroutine_movability(def_id) == hir::Movability::Movable,
467 ty::Error(_) => return None,
468 _ => span_bug!(body.span, "unexpected coroutine type {}", coroutine_ty),
469 };
470
471 let always_live_locals = always_storage_live_locals(body);
474 let liveness_info = locals_live_across_suspend_points(tcx, body, &always_live_locals, movable);
475
476 let (_, coroutine_layout, _) = compute_layout(liveness_info, body);
480
481 check_suspend_tys(tcx, &coroutine_layout, body);
482 check_field_tys_sized(tcx, &coroutine_layout, def_id);
483
484 Some(coroutine_layout)
485}
486
487fn check_field_tys_sized<'tcx>(
488 tcx: TyCtxt<'tcx>,
489 coroutine_layout: &CoroutineLayout<'tcx>,
490 def_id: LocalDefId,
491) {
492 if !tcx.features().unsized_fn_params() {
495 return;
496 }
497
498 let infcx = tcx.infer_ctxt().ignoring_regions().build(TypingMode::non_body_analysis());
503 let param_env = tcx.param_env(def_id);
504
505 let ocx = ObligationCtxt::new_with_diagnostics(&infcx);
506 for field_ty in &coroutine_layout.field_tys {
507 ocx.register_bound(
508 ObligationCause::new(
509 field_ty.source_info.span,
510 def_id,
511 ObligationCauseCode::SizedCoroutineInterior(def_id),
512 ),
513 param_env,
514 field_ty.ty,
515 tcx.require_lang_item(hir::LangItem::Sized, field_ty.source_info.span),
516 );
517 }
518
519 let errors = ocx.evaluate_obligations_error_on_ambiguity();
520 debug!(?errors);
521 if !errors.is_empty() {
522 infcx.err_ctxt().report_fulfillment_errors(errors);
523 }
524}
525
526fn check_suspend_tys<'tcx>(tcx: TyCtxt<'tcx>, layout: &CoroutineLayout<'tcx>, body: &Body<'tcx>) {
527 let mut linted_tys = FxHashSet::default();
528
529 for (variant, yield_source_info) in
530 layout.variant_fields.iter().zip(&layout.variant_source_info)
531 {
532 debug!(?variant);
533 for &local in variant {
534 let decl = &layout.field_tys[local];
535 debug!(?decl);
536
537 if !decl.ignore_for_traits && linted_tys.insert(decl.ty) {
538 let Some(hir_id) = decl.source_info.scope.lint_root(&body.source_scopes) else {
539 continue;
540 };
541
542 check_must_not_suspend_ty(
543 tcx,
544 decl.ty,
545 hir_id,
546 SuspendCheckData {
547 source_span: decl.source_info.span,
548 yield_span: yield_source_info.span,
549 plural_len: 1,
550 ..Default::default()
551 },
552 );
553 }
554 }
555 }
556}
557
/// Context threaded through `check_must_not_suspend_ty` while it recurses into a type.
#[derive(Default)]
struct SuspendCheckData<'a> {
    /// Span of the value held across the suspension point; used as the diagnostic span.
    source_span: Span,
    /// Span of the suspension point the value is held across.
    yield_span: Span,
    /// Text prepended to the type's description, accumulated while recursing
    /// (e.g. "boxed ", "reference to ").
    descr_pre: &'a str,
    /// Text appended to the type's description (e.g. " in tuple element 0").
    descr_post: &'a str,
    /// Count used to choose singular or plural wording (via `pluralize!`).
    plural_len: usize,
}
566
567fn check_must_not_suspend_ty<'tcx>(
574 tcx: TyCtxt<'tcx>,
575 ty: Ty<'tcx>,
576 hir_id: hir::HirId,
577 data: SuspendCheckData<'_>,
578) -> bool {
579 if ty.is_unit() {
580 return false;
581 }
582
583 let plural_suffix = pluralize!(data.plural_len);
584
585 debug!("Checking must_not_suspend for {}", ty);
586
587 match *ty.kind() {
588 ty::Adt(_, args) if ty.is_box() => {
589 let boxed_ty = args.type_at(0);
590 let allocator_ty = args.type_at(1);
591 check_must_not_suspend_ty(
592 tcx,
593 boxed_ty,
594 hir_id,
595 SuspendCheckData { descr_pre: &format!("{}boxed ", data.descr_pre), ..data },
596 ) || check_must_not_suspend_ty(
597 tcx,
598 allocator_ty,
599 hir_id,
600 SuspendCheckData { descr_pre: &format!("{}allocator ", data.descr_pre), ..data },
601 )
602 }
603 ty::Adt(def, _) if def.repr().scalable() => {
609 tcx.dcx()
610 .span_err(data.source_span, "scalable vectors cannot be held over await points");
611 true
612 }
613 ty::Adt(def, _) => check_must_not_suspend_def(tcx, def.did(), hir_id, data),
614 ty::Alias(ty::AliasTy { kind: ty::Opaque { def_id: def }, .. }) => {
616 let mut has_emitted = false;
617 for &(predicate, _) in tcx.explicit_item_bounds(def).skip_binder() {
618 if let ty::ClauseKind::Trait(ref poly_trait_predicate) =
620 predicate.kind().skip_binder()
621 {
622 let def_id = poly_trait_predicate.trait_ref.def_id;
623 let descr_pre = &format!("{}implementer{} of ", data.descr_pre, plural_suffix);
624 if check_must_not_suspend_def(
625 tcx,
626 def_id,
627 hir_id,
628 SuspendCheckData { descr_pre, ..data },
629 ) {
630 has_emitted = true;
631 break;
632 }
633 }
634 }
635 has_emitted
636 }
637 ty::Dynamic(binder, _) => {
638 let mut has_emitted = false;
639 for predicate in binder.iter() {
640 if let ty::ExistentialPredicate::Trait(ref trait_ref) = predicate.skip_binder() {
641 let def_id = trait_ref.def_id;
642 let descr_post = &format!(" trait object{}{}", plural_suffix, data.descr_post);
643 if check_must_not_suspend_def(
644 tcx,
645 def_id,
646 hir_id,
647 SuspendCheckData { descr_post, ..data },
648 ) {
649 has_emitted = true;
650 break;
651 }
652 }
653 }
654 has_emitted
655 }
656 ty::Tuple(fields) => {
657 let mut has_emitted = false;
658 for (i, ty) in fields.iter().enumerate() {
659 let descr_post = &format!(" in tuple element {i}");
660 if check_must_not_suspend_ty(
661 tcx,
662 ty,
663 hir_id,
664 SuspendCheckData { descr_post, ..data },
665 ) {
666 has_emitted = true;
667 }
668 }
669 has_emitted
670 }
671 ty::Array(ty, len) => {
672 let descr_pre = &format!("{}array{} of ", data.descr_pre, plural_suffix);
673 check_must_not_suspend_ty(
674 tcx,
675 ty,
676 hir_id,
677 SuspendCheckData {
678 descr_pre,
679 plural_len: len.try_to_target_usize(tcx).unwrap_or(0) as usize + 1,
681 ..data
682 },
683 )
684 }
685 ty::Ref(_region, ty, _mutability) => {
688 let descr_pre = &format!("{}reference{} to ", data.descr_pre, plural_suffix);
689 check_must_not_suspend_ty(tcx, ty, hir_id, SuspendCheckData { descr_pre, ..data })
690 }
691 _ => false,
692 }
693}
694
695fn check_must_not_suspend_def(
696 tcx: TyCtxt<'_>,
697 def_id: DefId,
698 hir_id: hir::HirId,
699 data: SuspendCheckData<'_>,
700) -> bool {
701 if let Some(reason_str) = find_attr!(tcx, def_id, MustNotSupend {reason} => reason) {
702 let reason = reason_str.map(|s| MustNotSuspendReason { span: data.source_span, reason: s });
703 tcx.emit_node_span_lint(
704 rustc_session::lint::builtin::MUST_NOT_SUSPEND,
705 hir_id,
706 data.source_span,
707 MustNotSupend {
708 tcx,
709 yield_sp: data.yield_span,
710 reason,
711 src_sp: data.source_span,
712 pre: data.descr_pre,
713 def_id,
714 post: data.descr_post,
715 },
716 );
717
718 true
719 } else {
720 false
721 }
722}