rustc_mir_transform/shim/
async_destructor_ctor.rs

1use std::iter;
2
3use itertools::Itertools;
4use rustc_abi::{FieldIdx, VariantIdx};
5use rustc_const_eval::interpret;
6use rustc_hir::def_id::DefId;
7use rustc_hir::lang_items::LangItem;
8use rustc_index::{Idx, IndexVec};
9use rustc_middle::mir::*;
10use rustc_middle::ty::adjustment::PointerCoercion;
11use rustc_middle::ty::util::{AsyncDropGlueMorphology, Discr};
12use rustc_middle::ty::{self, Ty, TyCtxt};
13use rustc_middle::{bug, span_bug};
14use rustc_span::source_map::respan;
15use rustc_span::{Span, Symbol};
16use rustc_target::spec::PanicStrategy;
17use tracing::debug;
18
19use super::{local_decls_for_sig, new_body};
20
21pub(super) fn build_async_destructor_ctor_shim<'tcx>(
22    tcx: TyCtxt<'tcx>,
23    def_id: DefId,
24    ty: Option<Ty<'tcx>>,
25) -> Body<'tcx> {
26    debug!("build_drop_shim(def_id={:?}, ty={:?})", def_id, ty);
27
28    AsyncDestructorCtorShimBuilder::new(tcx, def_id, ty).build()
29}
30
/// Builder for the `async_drop_in_place` constructor shim.
///
/// It works as a stack machine that assembles an expression out of
/// combinators. The stack holds `Operand`s; each operand's type is recovered
/// from `locals`. A combinator is a not-yet-instantiated pair of a lang-item
/// function and its generic arguments. Applying a combinator consumes operands
/// from the top of the stack: it instantiates the function and moves (or
/// copies) those operands into a call to it. The call's result is then pushed
/// back onto the stack. The top of the stack is the last operand of the next
/// call.
// FIXME: add mir-opt tests
struct AsyncDestructorCtorShimBuilder<'tcx> {
    tcx: TyCtxt<'tcx>,
    /// The shim item; used for its signature, span, typing env and `MirSource`.
    def_id: DefId,
    /// Concrete type the shim is built for. `None` makes `build` emit a
    /// zero-sized return value instead of real drop glue.
    self_ty: Option<Ty<'tcx>>,
    /// Span of `def_id`, used for every generated statement and constant.
    span: Span,
    /// Outermost source info for `span`, shared by all generated MIR.
    source_info: SourceInfo,
    /// Typing environment used to decide whether cleanup must drop a local.
    typing_env: ty::TypingEnv<'tcx>,

    /// Operand stack of the combinator machine.
    stack: Vec<Operand<'tcx>>,
    /// Basic block currently being appended to; it has no terminator yet.
    last_bb: BasicBlock,
    /// Entry of the cleanup chain for the current stack contents, or `None`
    /// when the panic strategy is `abort`.
    top_cleanup_bb: Option<BasicBlock>,

    /// Local declarations of the body being built (return place and argument first).
    locals: IndexVec<Local, LocalDecl<'tcx>>,
    /// Basic blocks of the body being built.
    bbs: IndexVec<BasicBlock, BasicBlockData<'tcx>>,
}
54
/// Which user-defined destructor, if any, must run on the value itself
/// before its fields are dropped.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SurfaceDropKind {
    /// The ADT has an async destructor.
    Async,
    /// The ADT has only a synchronous `Drop` implementation.
    Sync,
}
60
61impl<'tcx> AsyncDestructorCtorShimBuilder<'tcx> {
62    const SELF_PTR: Local = Local::from_u32(1);
63    const INPUT_COUNT: usize = 1;
64    const MAX_STACK_LEN: usize = 2;
65
66    fn new(tcx: TyCtxt<'tcx>, def_id: DefId, self_ty: Option<Ty<'tcx>>) -> Self {
67        let args = if let Some(ty) = self_ty {
68            tcx.mk_args(&[ty.into()])
69        } else {
70            ty::GenericArgs::identity_for_item(tcx, def_id)
71        };
72        let sig = tcx.fn_sig(def_id).instantiate(tcx, args);
73        let sig = tcx.instantiate_bound_regions_with_erased(sig);
74        let span = tcx.def_span(def_id);
75
76        let source_info = SourceInfo::outermost(span);
77
78        debug_assert_eq!(sig.inputs().len(), Self::INPUT_COUNT);
79        let locals = local_decls_for_sig(&sig, span);
80
81        // Usual case: noop() + unwind resume + return
82        let mut bbs = IndexVec::with_capacity(3);
83        let typing_env = ty::TypingEnv::post_analysis(tcx, def_id);
84        AsyncDestructorCtorShimBuilder {
85            tcx,
86            def_id,
87            self_ty,
88            span,
89            source_info,
90            typing_env,
91
92            stack: Vec::with_capacity(Self::MAX_STACK_LEN),
93            last_bb: bbs.push(BasicBlockData::new(None, false)),
94            top_cleanup_bb: match tcx.sess.panic_strategy() {
95                PanicStrategy::Unwind => {
96                    // Don't drop input arg because it's just a pointer
97                    Some(bbs.push(BasicBlockData {
98                        statements: Vec::new(),
99                        terminator: Some(Terminator {
100                            source_info,
101                            kind: TerminatorKind::UnwindResume,
102                        }),
103                        is_cleanup: true,
104                    }))
105                }
106                PanicStrategy::Abort => None,
107            },
108
109            locals,
110            bbs,
111        }
112    }
113
114    fn build(self) -> Body<'tcx> {
115        let (tcx, Some(self_ty)) = (self.tcx, self.self_ty) else {
116            return self.build_zst_output();
117        };
118        match self_ty.async_drop_glue_morphology(tcx) {
119            AsyncDropGlueMorphology::Noop => span_bug!(
120                self.span,
121                "async drop glue shim generator encountered type with noop async drop glue morphology"
122            ),
123            AsyncDropGlueMorphology::DeferredDropInPlace => {
124                return self.build_deferred_drop_in_place();
125            }
126            AsyncDropGlueMorphology::Custom => (),
127        }
128
129        let surface_drop_kind = || {
130            let adt_def = self_ty.ty_adt_def()?;
131            if adt_def.async_destructor(tcx).is_some() {
132                Some(SurfaceDropKind::Async)
133            } else if adt_def.destructor(tcx).is_some() {
134                Some(SurfaceDropKind::Sync)
135            } else {
136                None
137            }
138        };
139
140        match self_ty.kind() {
141            ty::Array(elem_ty, _) => self.build_slice(true, *elem_ty),
142            ty::Slice(elem_ty) => self.build_slice(false, *elem_ty),
143
144            ty::Tuple(elem_tys) => self.build_chain(None, elem_tys.iter()),
145            ty::Adt(adt_def, args) if adt_def.is_struct() => {
146                let field_tys = adt_def.non_enum_variant().fields.iter().map(|f| f.ty(tcx, args));
147                self.build_chain(surface_drop_kind(), field_tys)
148            }
149            ty::Closure(_, args) => self.build_chain(None, args.as_closure().upvar_tys().iter()),
150            ty::CoroutineClosure(_, args) => {
151                self.build_chain(None, args.as_coroutine_closure().upvar_tys().iter())
152            }
153
154            ty::Adt(adt_def, args) if adt_def.is_enum() => {
155                self.build_enum(*adt_def, *args, surface_drop_kind())
156            }
157
158            ty::Adt(adt_def, _) => {
159                assert!(adt_def.is_union());
160                match surface_drop_kind().unwrap() {
161                    SurfaceDropKind::Async => self.build_fused_async_surface(),
162                    SurfaceDropKind::Sync => self.build_fused_sync_surface(),
163                }
164            }
165
166            ty::Bound(..)
167            | ty::Foreign(_)
168            | ty::Placeholder(_)
169            | ty::Infer(ty::FreshTy(_) | ty::FreshIntTy(_) | ty::FreshFloatTy(_) | ty::TyVar(_))
170            | ty::Param(_)
171            | ty::Alias(..) => {
172                bug!("Building async destructor for unexpected type: {self_ty:?}")
173            }
174
175            _ => {
176                bug!(
177                    "Building async destructor constructor shim is not yet implemented for type: {self_ty:?}"
178                )
179            }
180        }
181    }
182
183    fn build_enum(
184        mut self,
185        adt_def: ty::AdtDef<'tcx>,
186        args: ty::GenericArgsRef<'tcx>,
187        surface_drop: Option<SurfaceDropKind>,
188    ) -> Body<'tcx> {
189        let tcx = self.tcx;
190
191        let surface = match surface_drop {
192            None => None,
193            Some(kind) => {
194                self.put_self();
195                Some(match kind {
196                    SurfaceDropKind::Async => self.combine_async_surface(),
197                    SurfaceDropKind::Sync => self.combine_sync_surface(),
198                })
199            }
200        };
201
202        let mut other = None;
203        for (variant_idx, discr) in adt_def.discriminants(tcx) {
204            let variant = adt_def.variant(variant_idx);
205
206            let mut chain = None;
207            for (field_idx, field) in variant.fields.iter_enumerated() {
208                let field_ty = field.ty(tcx, args);
209                self.put_variant_field(variant.name, variant_idx, field_idx, field_ty);
210                let defer = self.combine_defer(field_ty);
211                chain = Some(match chain {
212                    None => defer,
213                    Some(chain) => self.combine_chain(chain, defer),
214                })
215            }
216            let variant_dtor = chain.unwrap_or_else(|| self.put_noop());
217
218            other = Some(match other {
219                None => variant_dtor,
220                Some(other) => {
221                    self.put_self();
222                    self.put_discr(discr);
223                    self.combine_either(other, variant_dtor)
224                }
225            });
226        }
227        let variants_dtor = other.unwrap_or_else(|| self.put_noop());
228
229        let dtor = match surface {
230            None => variants_dtor,
231            Some(surface) => self.combine_chain(surface, variants_dtor),
232        };
233        self.combine_fuse(dtor);
234        self.return_()
235    }
236
237    fn build_chain<I>(mut self, surface_drop: Option<SurfaceDropKind>, elem_tys: I) -> Body<'tcx>
238    where
239        I: Iterator<Item = Ty<'tcx>> + ExactSizeIterator,
240    {
241        let surface = match surface_drop {
242            None => None,
243            Some(kind) => {
244                self.put_self();
245                Some(match kind {
246                    SurfaceDropKind::Async => self.combine_async_surface(),
247                    SurfaceDropKind::Sync => self.combine_sync_surface(),
248                })
249            }
250        };
251
252        let mut chain = None;
253        for (field_idx, field_ty) in elem_tys.enumerate().map(|(i, ty)| (FieldIdx::new(i), ty)) {
254            self.put_field(field_idx, field_ty);
255            let defer = self.combine_defer(field_ty);
256            chain = Some(match chain {
257                None => defer,
258                Some(chain) => self.combine_chain(chain, defer),
259            })
260        }
261        let chain = chain.unwrap_or_else(|| self.put_noop());
262
263        let dtor = match surface {
264            None => chain,
265            Some(surface) => self.combine_chain(surface, chain),
266        };
267        self.combine_fuse(dtor);
268        self.return_()
269    }
270
271    fn build_zst_output(mut self) -> Body<'tcx> {
272        self.put_zst_output();
273        self.return_()
274    }
275
276    fn build_deferred_drop_in_place(mut self) -> Body<'tcx> {
277        self.put_self();
278        let deferred = self.combine_deferred_drop_in_place();
279        self.combine_fuse(deferred);
280        self.return_()
281    }
282
283    fn build_fused_async_surface(mut self) -> Body<'tcx> {
284        self.put_self();
285        let surface = self.combine_async_surface();
286        self.combine_fuse(surface);
287        self.return_()
288    }
289
290    fn build_fused_sync_surface(mut self) -> Body<'tcx> {
291        self.put_self();
292        let surface = self.combine_sync_surface();
293        self.combine_fuse(surface);
294        self.return_()
295    }
296
297    fn build_slice(mut self, is_array: bool, elem_ty: Ty<'tcx>) -> Body<'tcx> {
298        if is_array {
299            self.put_array_as_slice(elem_ty)
300        } else {
301            self.put_self()
302        }
303        let dtor = self.combine_slice(elem_ty);
304        self.combine_fuse(dtor);
305        self.return_()
306    }
307
308    fn put_zst_output(&mut self) {
309        let return_ty = self.locals[RETURN_PLACE].ty;
310        self.put_operand(Operand::Constant(Box::new(ConstOperand {
311            span: self.span,
312            user_ty: None,
313            const_: Const::zero_sized(return_ty),
314        })));
315    }
316
317    /// Puts `to_drop: *mut Self` on top of the stack.
318    fn put_self(&mut self) {
319        self.put_operand(Operand::Copy(Self::SELF_PTR.into()))
320    }
321
322    /// Given that `Self is [ElemTy; N]` puts `to_drop: *mut [ElemTy]`
323    /// on top of the stack.
324    fn put_array_as_slice(&mut self, elem_ty: Ty<'tcx>) {
325        let slice_ptr_ty = Ty::new_mut_ptr(self.tcx, Ty::new_slice(self.tcx, elem_ty));
326        self.put_temp_rvalue(Rvalue::Cast(
327            CastKind::PointerCoercion(PointerCoercion::Unsize, CoercionSource::Implicit),
328            Operand::Copy(Self::SELF_PTR.into()),
329            slice_ptr_ty,
330        ))
331    }
332
333    /// If given Self is a struct puts `to_drop: *mut FieldTy` on top
334    /// of the stack.
335    fn put_field(&mut self, field: FieldIdx, field_ty: Ty<'tcx>) {
336        let place = Place {
337            local: Self::SELF_PTR,
338            projection: self
339                .tcx
340                .mk_place_elems(&[PlaceElem::Deref, PlaceElem::Field(field, field_ty)]),
341        };
342        self.put_temp_rvalue(Rvalue::RawPtr(RawPtrKind::Mut, place))
343    }
344
345    /// If given Self is an enum puts `to_drop: *mut FieldTy` on top of
346    /// the stack.
347    fn put_variant_field(
348        &mut self,
349        variant_sym: Symbol,
350        variant: VariantIdx,
351        field: FieldIdx,
352        field_ty: Ty<'tcx>,
353    ) {
354        let place = Place {
355            local: Self::SELF_PTR,
356            projection: self.tcx.mk_place_elems(&[
357                PlaceElem::Deref,
358                PlaceElem::Downcast(Some(variant_sym), variant),
359                PlaceElem::Field(field, field_ty),
360            ]),
361        };
362        self.put_temp_rvalue(Rvalue::RawPtr(RawPtrKind::Mut, place))
363    }
364
365    /// If given Self is an enum puts `to_drop: *mut FieldTy` on top of
366    /// the stack.
367    fn put_discr(&mut self, discr: Discr<'tcx>) {
368        let (size, _) = discr.ty.int_size_and_signed(self.tcx);
369        self.put_operand(Operand::const_from_scalar(
370            self.tcx,
371            discr.ty,
372            interpret::Scalar::from_uint(discr.val, size),
373            self.span,
374        ));
375    }
376
377    /// Puts `x: RvalueType` on top of the stack.
378    fn put_temp_rvalue(&mut self, rvalue: Rvalue<'tcx>) {
379        let last_bb = &mut self.bbs[self.last_bb];
380        debug_assert!(last_bb.terminator.is_none());
381        let source_info = self.source_info;
382
383        let local_ty = rvalue.ty(&self.locals, self.tcx);
384        // We need to create a new local to be able to "consume" it with
385        // a combinator
386        let local = self.locals.push(LocalDecl::with_source_info(local_ty, source_info));
387        last_bb.statements.extend_from_slice(&[
388            Statement { source_info, kind: StatementKind::StorageLive(local) },
389            Statement {
390                source_info,
391                kind: StatementKind::Assign(Box::new((local.into(), rvalue))),
392            },
393        ]);
394
395        self.put_operand(Operand::Move(local.into()));
396    }
397
398    /// Puts operand on top of the stack.
399    fn put_operand(&mut self, operand: Operand<'tcx>) {
400        if let Some(top_cleanup_bb) = &mut self.top_cleanup_bb {
401            let source_info = self.source_info;
402            match &operand {
403                Operand::Copy(_) | Operand::Constant(_) => {
404                    *top_cleanup_bb = self.bbs.push(BasicBlockData {
405                        statements: Vec::new(),
406                        terminator: Some(Terminator {
407                            source_info,
408                            kind: TerminatorKind::Goto { target: *top_cleanup_bb },
409                        }),
410                        is_cleanup: true,
411                    });
412                }
413                Operand::Move(place) => {
414                    let local = place.as_local().unwrap();
415                    *top_cleanup_bb = self.bbs.push(BasicBlockData {
416                        statements: Vec::new(),
417                        terminator: Some(Terminator {
418                            source_info,
419                            kind: if self.locals[local].ty.needs_drop(self.tcx, self.typing_env) {
420                                TerminatorKind::Drop {
421                                    place: local.into(),
422                                    target: *top_cleanup_bb,
423                                    unwind: UnwindAction::Terminate(
424                                        UnwindTerminateReason::InCleanup,
425                                    ),
426                                    replace: false,
427                                }
428                            } else {
429                                TerminatorKind::Goto { target: *top_cleanup_bb }
430                            },
431                        }),
432                        is_cleanup: true,
433                    });
434                }
435            };
436        }
437        self.stack.push(operand);
438    }
439
440    /// Puts `noop: async_drop::Noop` on top of the stack
441    fn put_noop(&mut self) -> Ty<'tcx> {
442        self.apply_combinator(0, LangItem::AsyncDropNoop, &[])
443    }
444
445    fn combine_async_surface(&mut self) -> Ty<'tcx> {
446        self.apply_combinator(1, LangItem::SurfaceAsyncDropInPlace, &[self.self_ty.unwrap().into()])
447    }
448
449    fn combine_sync_surface(&mut self) -> Ty<'tcx> {
450        self.apply_combinator(
451            1,
452            LangItem::AsyncDropSurfaceDropInPlace,
453            &[self.self_ty.unwrap().into()],
454        )
455    }
456
457    fn combine_deferred_drop_in_place(&mut self) -> Ty<'tcx> {
458        self.apply_combinator(
459            1,
460            LangItem::AsyncDropDeferredDropInPlace,
461            &[self.self_ty.unwrap().into()],
462        )
463    }
464
465    fn combine_fuse(&mut self, inner_future_ty: Ty<'tcx>) -> Ty<'tcx> {
466        self.apply_combinator(1, LangItem::AsyncDropFuse, &[inner_future_ty.into()])
467    }
468
469    fn combine_slice(&mut self, elem_ty: Ty<'tcx>) -> Ty<'tcx> {
470        self.apply_combinator(1, LangItem::AsyncDropSlice, &[elem_ty.into()])
471    }
472
473    fn combine_defer(&mut self, to_drop_ty: Ty<'tcx>) -> Ty<'tcx> {
474        self.apply_combinator(1, LangItem::AsyncDropDefer, &[to_drop_ty.into()])
475    }
476
477    fn combine_chain(&mut self, first: Ty<'tcx>, second: Ty<'tcx>) -> Ty<'tcx> {
478        self.apply_combinator(2, LangItem::AsyncDropChain, &[first.into(), second.into()])
479    }
480
481    fn combine_either(&mut self, other: Ty<'tcx>, matched: Ty<'tcx>) -> Ty<'tcx> {
482        self.apply_combinator(
483            4,
484            LangItem::AsyncDropEither,
485            &[other.into(), matched.into(), self.self_ty.unwrap().into()],
486        )
487    }
488
489    fn return_(mut self) -> Body<'tcx> {
490        let last_bb = &mut self.bbs[self.last_bb];
491        debug_assert!(last_bb.terminator.is_none());
492        let source_info = self.source_info;
493
494        let (1, Some(output)) = (self.stack.len(), self.stack.pop()) else {
495            span_bug!(
496                self.span,
497                "async destructor ctor shim builder finished with invalid number of stack items: expected 1 found {}",
498                self.stack.len(),
499            )
500        };
501        #[cfg(debug_assertions)]
502        if let Some(ty) = self.self_ty {
503            debug_assert_eq!(
504                output.ty(&self.locals, self.tcx),
505                ty.async_destructor_ty(self.tcx),
506                "output async destructor types did not match for type: {ty:?}",
507            );
508        }
509
510        let dead_storage = match &output {
511            Operand::Move(place) => Some(Statement {
512                source_info,
513                kind: StatementKind::StorageDead(place.as_local().unwrap()),
514            }),
515            _ => None,
516        };
517
518        last_bb.statements.extend(
519            iter::once(Statement {
520                source_info,
521                kind: StatementKind::Assign(Box::new((RETURN_PLACE.into(), Rvalue::Use(output)))),
522            })
523            .chain(dead_storage),
524        );
525
526        last_bb.terminator = Some(Terminator { source_info, kind: TerminatorKind::Return });
527
528        let source = MirSource::from_instance(ty::InstanceKind::AsyncDropGlueCtorShim(
529            self.def_id,
530            self.self_ty,
531        ));
532        new_body(source, self.bbs, self.locals, Self::INPUT_COUNT, self.span)
533    }
534
535    fn apply_combinator(
536        &mut self,
537        arity: usize,
538        function: LangItem,
539        args: &[ty::GenericArg<'tcx>],
540    ) -> Ty<'tcx> {
541        let function = self.tcx.require_lang_item(function, Some(self.span));
542        let operands_split = self
543            .stack
544            .len()
545            .checked_sub(arity)
546            .expect("async destructor ctor shim combinator tried to consume too many items");
547        let operands = &self.stack[operands_split..];
548
549        let func_ty = Ty::new_fn_def(self.tcx, function, args.iter().copied());
550        let func_sig = func_ty.fn_sig(self.tcx).no_bound_vars().unwrap();
551        #[cfg(debug_assertions)]
552        operands.iter().zip(func_sig.inputs()).for_each(|(operand, expected_ty)| {
553            let operand_ty = operand.ty(&self.locals, self.tcx);
554            if operand_ty == *expected_ty {
555                return;
556            }
557
558            // If projection of Discriminant then compare with `Ty::discriminant_ty`
559            if let ty::Alias(ty::Projection, ty::AliasTy { args, def_id, .. }) = expected_ty.kind()
560                && self.tcx.is_lang_item(*def_id, LangItem::Discriminant)
561                && args.first().unwrap().as_type().unwrap().discriminant_ty(self.tcx) == operand_ty
562            {
563                return;
564            }
565
566            span_bug!(
567                self.span,
568                "Operand type and combinator argument type are not equal.
569    operand_ty: {:?}
570    argument_ty: {:?}
571",
572                operand_ty,
573                expected_ty
574            );
575        });
576
577        let target = self.bbs.push(BasicBlockData {
578            statements: operands
579                .iter()
580                .rev()
581                .filter_map(|o| {
582                    if let Operand::Move(Place { local, projection }) = o {
583                        assert!(projection.is_empty());
584                        Some(Statement {
585                            source_info: self.source_info,
586                            kind: StatementKind::StorageDead(*local),
587                        })
588                    } else {
589                        None
590                    }
591                })
592                .collect(),
593            terminator: None,
594            is_cleanup: false,
595        });
596
597        let dest_ty = func_sig.output();
598        let dest =
599            self.locals.push(LocalDecl::with_source_info(dest_ty, self.source_info).immutable());
600
601        let unwind = if let Some(top_cleanup_bb) = &mut self.top_cleanup_bb {
602            for _ in 0..arity {
603                *top_cleanup_bb =
604                    self.bbs[*top_cleanup_bb].terminator().successors().exactly_one().ok().unwrap();
605            }
606            UnwindAction::Cleanup(*top_cleanup_bb)
607        } else {
608            UnwindAction::Unreachable
609        };
610
611        let last_bb = &mut self.bbs[self.last_bb];
612        debug_assert!(last_bb.terminator.is_none());
613        last_bb.statements.push(Statement {
614            source_info: self.source_info,
615            kind: StatementKind::StorageLive(dest),
616        });
617        last_bb.terminator = Some(Terminator {
618            source_info: self.source_info,
619            kind: TerminatorKind::Call {
620                func: Operand::Constant(Box::new(ConstOperand {
621                    span: self.span,
622                    user_ty: None,
623                    const_: Const::Val(ConstValue::ZeroSized, func_ty),
624                })),
625                destination: dest.into(),
626                target: Some(target),
627                unwind,
628                call_source: CallSource::Misc,
629                fn_span: self.span,
630                args: self.stack.drain(operands_split..).map(|o| respan(self.span, o)).collect(),
631            },
632        });
633
634        self.put_operand(Operand::Move(dest.into()));
635        self.last_bb = target;
636
637        dest_ty
638    }
639}