1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
//! Deeply normalize types using the old trait solver.

use rustc_data_structures::stack::ensure_sufficient_stack;
use rustc_infer::infer::at::At;
use rustc_infer::infer::InferOk;
use rustc_infer::traits::{
    FromSolverError, Normalized, Obligation, PredicateObligation, TraitEngine,
};
use rustc_macros::extension;
use rustc_middle::traits::{ObligationCause, ObligationCauseCode, Reveal};
use rustc_middle::ty::{
    self, Ty, TyCtxt, TypeFoldable, TypeFolder, TypeSuperFoldable, TypeVisitable, TypeVisitableExt,
};
use tracing::{debug, instrument};

use super::{
    project, with_replaced_escaping_bound_vars, BoundVarReplacer, PlaceholderReplacer,
    SelectionContext,
};
use crate::error_reporting::traits::OverflowCause;
use crate::error_reporting::InferCtxtErrorExt;
use crate::solve::NextSolverError;

#[extension(pub trait NormalizeExt<'tcx>)]
impl<'tcx> At<'_, 'tcx> {
    /// Normalizes `value` with the old solver's `AssocTypeNormalizer`, returning the
    /// normalized value together with the nested obligations it produced.
    ///
    /// Use this when the type may contain inference variables or when projection
    /// may fail. When the next trait solver is enabled, `value` is returned
    /// unchanged and no obligations are produced.
    fn normalize<T: TypeFoldable<TyCtxt<'tcx>>>(&self, value: T) -> InferOk<'tcx, T> {
        if self.infcx.next_trait_solver() {
            return InferOk { value, obligations: Vec::new() };
        }

        let mut selcx = SelectionContext::new(self.infcx);
        let normalized =
            normalize_with_depth(&mut selcx, self.param_env, self.cause.clone(), 0, value);
        InferOk { value: normalized.value, obligations: normalized.obligations }
    }

    /// Deeply normalizes `value`, replacing every alias that can be normalized in
    /// the current environment.
    ///
    /// With the new solver this delegates to `crate::solve::deeply_normalize`, which
    /// errors if normalization fails or is ambiguous.
    ///
    /// With the old solver this runs `normalize`, registers the resulting nested
    /// obligations in `fulfill_cx`, selects on `fulfill_cx` where possible and
    /// returns any errors from that step. Reusing the caller's fulfillment context
    /// matters: otherwise the same goals would be recomputed in both a temporary
    /// and the shared context, and those do not share caching, which hurts
    /// performance.
    ///
    /// NOTE(review): `select_where_possible` runs over the whole `fulfill_cx`, so the
    /// returned errors may presumably include obligations registered before this
    /// call. Confirm against the `TraitEngine` implementation.
    ///
    /// FIXME(-Znext-solver): For performance reasons, we currently reuse an existing
    /// fulfillment context in the old solver. Once we have removed the old solver, we
    /// can remove the `fulfill_cx` parameter on this function.
    fn deeply_normalize<T, E>(
        self,
        value: T,
        fulfill_cx: &mut dyn TraitEngine<'tcx, E>,
    ) -> Result<T, Vec<E>>
    where
        T: TypeFoldable<TyCtxt<'tcx>>,
        E: FromSolverError<'tcx, NextSolverError<'tcx>>,
    {
        if self.infcx.next_trait_solver() {
            return crate::solve::deeply_normalize(self, value);
        }

        // Reborrow so `fulfill_cx` stays usable for the selection step below.
        let registered = self
            .normalize(value)
            .into_value_registering_obligations(self.infcx, &mut *fulfill_cx);
        let errors = fulfill_cx.select_where_possible(self.infcx);
        // Selection may have constrained inference variables; resolve them into the result.
        let resolved = self.infcx.resolve_vars_if_possible(registered);
        if !errors.is_empty() {
            return Err(errors);
        }
        Ok(resolved)
    }
}

/// As `normalize`, but with a custom starting depth.
///
/// Normalizes `value` with the old-solver normalizer, starting at recursion depth
/// `depth`. Returns the normalized value together with every obligation produced
/// while normalizing.
pub(crate) fn normalize_with_depth<'a, 'b, 'tcx, T>(
    selcx: &'a mut SelectionContext<'b, 'tcx>,
    param_env: ty::ParamEnv<'tcx>,
    cause: ObligationCause<'tcx>,
    depth: usize,
    value: T,
) -> Normalized<'tcx, T>
where
    T: TypeFoldable<TyCtxt<'tcx>>,
{
    let mut collected = vec![];
    let normalized_value =
        normalize_with_depth_to(selcx, param_env, cause, depth, value, &mut collected);
    Normalized { value: normalized_value, obligations: collected }
}

/// Like `normalize_with_depth`, but pushes the produced obligations into the
/// caller-supplied `obligations` vector instead of returning a fresh one.
///
/// Obligations already present in `obligations` are left in place.
#[instrument(level = "info", skip(selcx, param_env, cause, obligations))]
pub(crate) fn normalize_with_depth_to<'a, 'b, 'tcx, T>(
    selcx: &'a mut SelectionContext<'b, 'tcx>,
    param_env: ty::ParamEnv<'tcx>,
    cause: ObligationCause<'tcx>,
    depth: usize,
    value: T,
    obligations: &mut Vec<PredicateObligation<'tcx>>,
) -> T
where
    T: TypeFoldable<TyCtxt<'tcx>>,
{
    debug!(obligations.len = obligations.len());

    // NOTE: the local names `normalizer` and `result` double as tracing field
    // names in the `debug!` calls below, so keep them stable.
    let mut normalizer = AssocTypeNormalizer::new(selcx, param_env, cause, depth, obligations);

    // Folding recurses through nested types, so run it under `ensure_sufficient_stack`.
    let result = ensure_sufficient_stack(|| normalizer.fold(value));

    debug!(?result, obligations.len = normalizer.obligations.len());
    debug!(?normalizer.obligations);
    result
}

/// Returns `true` if `value` contains an alias that the normalizer would act on
/// under the given `reveal` mode.
///
/// Every alias counts under `Reveal::All`. Under `Reveal::UserFacing`, opaque types
/// are treated as rigid, so their presence alone does not require normalization.
pub(super) fn needs_normalization<'tcx, T: TypeVisitable<TyCtxt<'tcx>>>(
    value: &T,
    reveal: Reveal,
) -> bool {
    let flags = match reveal {
        Reveal::UserFacing => ty::TypeFlags::HAS_ALIAS.difference(ty::TypeFlags::HAS_TY_OPAQUE),
        Reveal::All => ty::TypeFlags::HAS_ALIAS,
    };
    value.has_type_flags(flags)
}

/// Type folder that normalizes aliases (projections, opaques, weak and inherent
/// aliases) for the old trait solver, collecting nested obligations as it goes.
struct AssocTypeNormalizer<'a, 'b, 'tcx> {
    selcx: &'a mut SelectionContext<'b, 'tcx>,
    param_env: ty::ParamEnv<'tcx>,
    /// Cause handed to the projection helpers and used for obligations created here.
    cause: ObligationCause<'tcx>,
    /// Sink that nested obligations produced during normalization are pushed into.
    obligations: &'a mut Vec<PredicateObligation<'tcx>>,
    /// Current normalization depth. It is incremented while expanding opaque and
    /// weak aliases and compared against the recursion limit.
    depth: usize,
    /// One slot per binder currently being folded through (pushed as `None`).
    /// It is passed to `BoundVarReplacer`/`PlaceholderReplacer`.
    universes: Vec<Option<ty::UniverseIndex>>,
}

impl<'a, 'b, 'tcx> AssocTypeNormalizer<'a, 'b, 'tcx> {
    /// Creates a normalizer that starts at `depth` with no binders entered.
    ///
    /// Only valid with the old trait solver; this is checked in debug builds.
    fn new(
        selcx: &'a mut SelectionContext<'b, 'tcx>,
        param_env: ty::ParamEnv<'tcx>,
        cause: ObligationCause<'tcx>,
        depth: usize,
        obligations: &'a mut Vec<PredicateObligation<'tcx>>,
    ) -> Self {
        debug_assert!(!selcx.infcx.next_trait_solver());
        Self { selcx, param_env, cause, obligations, depth, universes: Vec::new() }
    }

    /// Entry point. Resolves inference variables in `value`, then folds it with
    /// `self` only if it contains something that needs normalization.
    ///
    /// # Panics
    ///
    /// Panics if `value` has escaping bound vars. Such values must be wrapped in a
    /// `Binder` before being normalized.
    fn fold<T: TypeFoldable<TyCtxt<'tcx>>>(&mut self, value: T) -> T {
        let value = self.selcx.infcx.resolve_vars_if_possible(value);
        debug!(?value);

        assert!(
            !value.has_escaping_bound_vars(),
            "Normalizing {value:?} without wrapping in a `Binder`"
        );

        let reveal = self.param_env.reveal();
        if needs_normalization(&value, reveal) {
            value.fold_with(self)
        } else {
            // Nothing to normalize: skip the fold entirely.
            value
        }
    }
}

// Folder that performs the actual normalization. It overrides `fold_binder` (to
// track the binders it passes through in `universes`), `fold_ty`, `fold_const`
// and `fold_predicate`. All other folding uses the default structural behavior.
impl<'a, 'b, 'tcx> TypeFolder<TyCtxt<'tcx>> for AssocTypeNormalizer<'a, 'b, 'tcx> {
    /// Returns the type context of the underlying selection context.
    fn cx(&self) -> TyCtxt<'tcx> {
        self.selcx.tcx()
    }

    /// Folds the contents of a binder while recording it in `universes`.
    ///
    /// A `None` slot is pushed for the duration of the fold and popped afterwards,
    /// so `universes` holds one entry per binder currently being traversed. This is
    /// the vector handed to `BoundVarReplacer`/`PlaceholderReplacer` in `fold_ty`.
    /// Presumably `BoundVarReplacer` only fills a slot with a real universe when it
    /// needs one; confirm in `BoundVarReplacer`.
    fn fold_binder<T: TypeFoldable<TyCtxt<'tcx>>>(
        &mut self,
        t: ty::Binder<'tcx, T>,
    ) -> ty::Binder<'tcx, T> {
        self.universes.push(None);
        let t = t.super_fold_with(self);
        self.universes.pop();
        t
    }

    /// Normalizes `ty`, dispatching on the kind of alias it is.
    ///
    /// Types that contain nothing to normalize are returned as-is. Non-alias types
    /// are folded structurally so that nested aliases are still reached. Nested
    /// obligations produced along the way may be pushed onto `self.obligations`.
    fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
        // Fast path: `ty` contains no alias (or, under `Reveal::UserFacing`, no
        // non-opaque alias), so there is nothing to do.
        if !needs_normalization(&ty, self.param_env.reveal()) {
            return ty;
        }

        let (kind, data) = match *ty.kind() {
            ty::Alias(kind, data) => (kind, data),
            // Not an alias itself: recurse into its components.
            _ => return ty.super_fold_with(self),
        };

        // We try to be a little clever here as a performance optimization in
        // cases where there are nested projections under binders.
        // For example:
        // ```
        // for<'a> fn(<T as Foo>::One<'a, Box<dyn Bar<'a, Item=<T as Foo>::Two<'a>>>>)
        // ```
        // We normalize the args on the projection before projecting, but
        // if we're naive, we'll
        //   replace bound vars on inner, project inner, replace placeholders on inner,
        //   replace bound vars on outer, project outer, replace placeholders on outer
        //
        // However, if we're a bit more clever, we can replace the bound vars
        // on the entire type before normalizing nested projections, meaning we
        //   replace bound vars on outer, project inner,
        //   project outer, replace placeholders on outer
        //
        // This is possible because the inner `'a` will already be a placeholder
        // when we need to normalize the inner projection
        //
        // On the other hand, this does add a bit of complexity, since we only
        // replace bound vars if the current type is a `Projection` and we need
        // to make sure we don't forget to fold the args regardless.

        match kind {
            ty::Opaque => {
                // Only normalize `impl Trait` outside of type inference, usually in codegen.
                match self.param_env.reveal() {
                    // Opaques are treated as rigid here: keep the alias and only
                    // fold its components.
                    Reveal::UserFacing => ty.super_fold_with(self),

                    Reveal::All => {
                        // Report an overflow error once expanding opaques has nested
                        // deeper than the crate's recursion limit.
                        let recursion_limit = self.cx().recursion_limit();
                        if !recursion_limit.value_within_limit(self.depth) {
                            self.selcx.infcx.err_ctxt().report_overflow_error(
                                OverflowCause::DeeplyNormalize(data.into()),
                                self.cause.span,
                                true,
                                |_| {},
                            );
                        }

                        // Replace the opaque with its `type_of`, instantiated with the
                        // normalized args, and normalize that result one level deeper.
                        let args = data.args.fold_with(self);
                        let generic_ty = self.cx().type_of(data.def_id);
                        let concrete_ty = generic_ty.instantiate(self.cx(), args);
                        self.depth += 1;
                        let folded_ty = self.fold_ty(concrete_ty);
                        self.depth -= 1;
                        folded_ty
                    }
                }
            }

            ty::Projection if !data.has_escaping_bound_vars() => {
                // This branch is *mostly* just an optimization: when we don't
                // have escaping bound vars, we don't need to replace them with
                // placeholders (see branch below). *Also*, we know that we can
                // register an obligation to *later* project, since we know
                // there won't be bound vars there.
                let data = data.fold_with(self);
                let normalized_ty = project::normalize_projection_ty(
                    self.selcx,
                    self.param_env,
                    data,
                    self.cause.clone(),
                    self.depth,
                    self.obligations,
                );
                debug!(
                    ?self.depth,
                    ?ty,
                    ?normalized_ty,
                    obligations.len = ?self.obligations.len(),
                    "AssocTypeNormalizer: normalized type"
                );
                // `normalize_projection_ty` yields a term; a type projection is
                // expected to normalize to a type.
                normalized_ty.expect_type()
            }

            ty::Projection => {
                // If there are escaping bound vars, we temporarily replace the
                // bound vars with placeholders. Note though, that in the case
                // that we still can't project for whatever reason (e.g. self
                // type isn't known enough), we *can't* register an obligation
                // and return an inference variable (since then that obligation
                // would have bound vars and that's a can of worms). Instead,
                // we just give up and fall back to pretending like we never tried!
                //
                // Note: this isn't necessarily the final approach here; we may
                // want to figure out how to register obligations with escaping vars
                // or handle this some other way.

                let infcx = self.selcx.infcx;
                let (data, mapped_regions, mapped_types, mapped_consts) =
                    BoundVarReplacer::replace_bound_vars(infcx, &mut self.universes, data);
                let data = data.fold_with(self);
                // `.ok().flatten()` collapses both an error and "could not project"
                // (`None`) into the structural-fold fallback at the end.
                let normalized_ty = project::opt_normalize_projection_term(
                    self.selcx,
                    self.param_env,
                    data.into(),
                    self.cause.clone(),
                    self.depth,
                    self.obligations,
                )
                .ok()
                .flatten()
                .map(|term| term.expect_type())
                .map(|normalized_ty| {
                    // Map the placeholders introduced above back to the original
                    // bound vars.
                    PlaceholderReplacer::replace_placeholders(
                        infcx,
                        mapped_regions,
                        mapped_types,
                        mapped_consts,
                        &self.universes,
                        normalized_ty,
                    )
                })
                .unwrap_or_else(|| ty.super_fold_with(self));

                debug!(
                    ?self.depth,
                    ?ty,
                    ?normalized_ty,
                    obligations.len = ?self.obligations.len(),
                    "AssocTypeNormalizer: normalized type"
                );
                normalized_ty
            }
            ty::Weak => {
                // Weak (type alias) expansion also counts toward the recursion limit.
                let recursion_limit = self.cx().recursion_limit();
                if !recursion_limit.value_within_limit(self.depth) {
                    self.selcx.infcx.err_ctxt().report_overflow_error(
                        OverflowCause::DeeplyNormalize(data.into()),
                        self.cause.span,
                        false,
                        |diag| {
                            diag.note(crate::fluent_generated::trait_selection_ty_alias_overflow);
                        },
                    );
                }

                // Register the alias's own predicates (instantiated with its args)
                // as obligations, tagging each cause with `TypeAlias`. If `data` has
                // escaping bound vars, they are replaced in each predicate, and the
                // replacement mappings are discarded (`..`).
                let infcx = self.selcx.infcx;
                self.obligations.extend(
                    infcx.tcx.predicates_of(data.def_id).instantiate_own(infcx.tcx, data.args).map(
                        |(mut predicate, span)| {
                            if data.has_escaping_bound_vars() {
                                (predicate, ..) = BoundVarReplacer::replace_bound_vars(
                                    infcx,
                                    &mut self.universes,
                                    predicate,
                                );
                            }
                            let mut cause = self.cause.clone();
                            cause.map_code(|code| {
                                ObligationCauseCode::TypeAlias(code, span, data.def_id)
                            });
                            Obligation::new(infcx.tcx, cause, self.param_env, predicate)
                        },
                    ),
                );
                // Expand the alias to its `type_of` and normalize the result one
                // level deeper.
                self.depth += 1;
                let res = infcx
                    .tcx
                    .type_of(data.def_id)
                    .instantiate(infcx.tcx, data.args)
                    .fold_with(self);
                self.depth -= 1;
                res
            }

            ty::Inherent if !data.has_escaping_bound_vars() => {
                // This branch is *mostly* just an optimization: when we don't
                // have escaping bound vars, we don't need to replace them with
                // placeholders (see branch below). *Also*, we know that we can
                // register an obligation to *later* project, since we know
                // there won't be bound vars there.

                let data = data.fold_with(self);

                project::normalize_inherent_projection(
                    self.selcx,
                    self.param_env,
                    data,
                    self.cause.clone(),
                    self.depth,
                    self.obligations,
                )
            }

            ty::Inherent => {
                // Escaping bound vars: replace them with placeholders, normalize,
                // then map the placeholders back. Unlike the `Projection` case
                // above, there is no fallback here. `normalize_inherent_projection`
                // returns a type directly.
                let infcx = self.selcx.infcx;
                let (data, mapped_regions, mapped_types, mapped_consts) =
                    BoundVarReplacer::replace_bound_vars(infcx, &mut self.universes, data);
                let data = data.fold_with(self);
                let ty = project::normalize_inherent_projection(
                    self.selcx,
                    self.param_env,
                    data,
                    self.cause.clone(),
                    self.depth,
                    self.obligations,
                );

                PlaceholderReplacer::replace_placeholders(
                    infcx,
                    mapped_regions,
                    mapped_types,
                    mapped_consts,
                    &self.universes,
                    ty,
                )
            }
        }
    }

    /// Normalizes a constant.
    ///
    /// Constants are left untouched when the `generic_const_exprs` feature is
    /// enabled or when they contain nothing to normalize. Otherwise their
    /// components are folded first, and then the constant itself is normalized
    /// via `Const::normalize`, with escaping bound vars temporarily replaced.
    #[instrument(skip(self), level = "debug")]
    fn fold_const(&mut self, constant: ty::Const<'tcx>) -> ty::Const<'tcx> {
        let tcx = self.selcx.tcx();
        if tcx.features().generic_const_exprs
            || !needs_normalization(&constant, self.param_env.reveal())
        {
            constant
        } else {
            let constant = constant.super_fold_with(self);
            debug!(?constant, ?self.param_env);
            with_replaced_escaping_bound_vars(
                self.selcx.infcx,
                &mut self.universes,
                constant,
                |constant| constant.normalize(tcx, self.param_env),
            )
        }
    }

    /// Folds a predicate only if it permits normalization (`allow_normalization`)
    /// and actually contains something to normalize. Otherwise it is returned
    /// unchanged.
    #[inline]
    fn fold_predicate(&mut self, p: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
        if p.allow_normalization() && needs_normalization(&p, self.param_env.reveal()) {
            p.super_fold_with(self)
        } else {
            p
        }
    }
}