1use std::iter;
2
3use rustc_abi::{FIRST_VARIANT, VariantIdx};
4use rustc_errors::ErrorGuaranteed;
5use rustc_hir::def::DefKind;
6use rustc_hir::def_id::LocalDefId;
7use rustc_middle::mir::interpret::LitToConstInput;
8use rustc_middle::query::Providers;
9use rustc_middle::thir::visit;
10use rustc_middle::thir::visit::Visitor;
11use rustc_middle::ty::abstract_const::CastKind;
12use rustc_middle::ty::{self, Expr, TyCtxt, TypeVisitableExt};
13use rustc_middle::{bug, mir, thir};
14use rustc_span::Span;
15use tracing::{debug, instrument};
16
17use crate::errors::{GenericConstantTooComplex, GenericConstantTooComplexSub};
18
19fn destructure_const<'tcx>(
22 tcx: TyCtxt<'tcx>,
23 const_: ty::Const<'tcx>,
24) -> ty::DestructuredConst<'tcx> {
25 let ty::ConstKind::Value(cv) = const_.kind() else {
26 bug!("cannot destructure constant {:?}", const_)
27 };
28
29 let branches = cv.valtree.unwrap_branch();
30
31 let (fields, variant) = match cv.ty.kind() {
32 ty::Array(inner_ty, _) | ty::Slice(inner_ty) => {
33 let field_consts = branches
35 .iter()
36 .map(|b| ty::Const::new_value(tcx, *b, *inner_ty))
37 .collect::<Vec<_>>();
38 debug!(?field_consts);
39
40 (field_consts, None)
41 }
42 ty::Adt(def, _) if def.variants().is_empty() => bug!("unreachable"),
43 ty::Adt(def, args) => {
44 let (variant_idx, branches) = if def.is_enum() {
45 let (head, rest) = branches.split_first().unwrap();
46 (VariantIdx::from_u32(head.unwrap_leaf().to_u32()), rest)
47 } else {
48 (FIRST_VARIANT, branches)
49 };
50 let fields = &def.variant(variant_idx).fields;
51 let mut field_consts = Vec::with_capacity(fields.len());
52
53 for (field, field_valtree) in iter::zip(fields, branches) {
54 let field_ty = field.ty(tcx, args);
55 let field_const = ty::Const::new_value(tcx, *field_valtree, field_ty);
56 field_consts.push(field_const);
57 }
58 debug!(?field_consts);
59
60 (field_consts, Some(variant_idx))
61 }
62 ty::Tuple(elem_tys) => {
63 let fields = iter::zip(*elem_tys, branches)
64 .map(|(elem_ty, elem_valtree)| ty::Const::new_value(tcx, *elem_valtree, elem_ty))
65 .collect::<Vec<_>>();
66
67 (fields, None)
68 }
69 _ => bug!("cannot destructure constant {:?}", const_),
70 };
71
72 let fields = tcx.arena.alloc_from_iter(fields);
73
74 ty::DestructuredConst { variant, fields }
75}
76
77fn check_binop(op: mir::BinOp) -> bool {
79 use mir::BinOp::*;
80 match op {
81 Add | AddUnchecked | AddWithOverflow | Sub | SubUnchecked | SubWithOverflow | Mul
82 | MulUnchecked | MulWithOverflow | Div | Rem | BitXor | BitAnd | BitOr | Shl
83 | ShlUnchecked | Shr | ShrUnchecked | Eq | Lt | Le | Ne | Ge | Gt | Cmp => true,
84 Offset => false,
85 }
86}
87
88fn check_unop(op: mir::UnOp) -> bool {
91 use mir::UnOp::*;
92 match op {
93 Not | Neg | PtrMetadata => true,
94 }
95}
96
97fn recurse_build<'tcx>(
98 tcx: TyCtxt<'tcx>,
99 body: &thir::Thir<'tcx>,
100 node: thir::ExprId,
101 root_span: Span,
102) -> Result<ty::Const<'tcx>, ErrorGuaranteed> {
103 use thir::ExprKind;
104 let node = &body.exprs[node];
105
106 let maybe_supported_error = |a| maybe_supported_error(tcx, a, root_span);
107 let error = |a| error(tcx, a, root_span);
108
109 Ok(match &node.kind {
110 &ExprKind::Scope { value, .. } => recurse_build(tcx, body, value, root_span)?,
112 &ExprKind::PlaceTypeAscription { source, .. }
113 | &ExprKind::ValueTypeAscription { source, .. } => {
114 recurse_build(tcx, body, source, root_span)?
115 }
116 &ExprKind::PlaceUnwrapUnsafeBinder { .. }
117 | &ExprKind::ValueUnwrapUnsafeBinder { .. }
118 | &ExprKind::WrapUnsafeBinder { .. } => {
119 todo!("FIXME(unsafe_binders)")
120 }
121 &ExprKind::Literal { lit, neg } => {
122 let sp = node.span;
123 tcx.at(sp).lit_to_const(LitToConstInput { lit: &lit.node, ty: node.ty, neg })
124 }
125 &ExprKind::NonHirLiteral { lit, user_ty: _ } => {
126 let val = ty::ValTree::from_scalar_int(tcx, lit);
127 ty::Const::new_value(tcx, val, node.ty)
128 }
129 &ExprKind::ZstLiteral { user_ty: _ } => ty::Const::zero_sized(tcx, node.ty),
130 &ExprKind::NamedConst { def_id, args, user_ty: _ } => {
131 let uneval = ty::UnevaluatedConst::new(def_id, args);
132 ty::Const::new_unevaluated(tcx, uneval)
133 }
134 ExprKind::ConstParam { param, .. } => ty::Const::new_param(tcx, *param),
135
136 ExprKind::Call { fun, args, .. } => {
137 let fun_ty = body.exprs[*fun].ty;
138 let fun = recurse_build(tcx, body, *fun, root_span)?;
139
140 let mut new_args = Vec::<ty::Const<'tcx>>::with_capacity(args.len());
141 for &id in args.iter() {
142 new_args.push(recurse_build(tcx, body, id, root_span)?);
143 }
144 ty::Const::new_expr(tcx, Expr::new_call(tcx, fun_ty, fun, new_args))
145 }
146 &ExprKind::Binary { op, lhs, rhs } if check_binop(op) => {
147 let lhs_ty = body.exprs[lhs].ty;
148 let lhs = recurse_build(tcx, body, lhs, root_span)?;
149 let rhs_ty = body.exprs[rhs].ty;
150 let rhs = recurse_build(tcx, body, rhs, root_span)?;
151 ty::Const::new_expr(tcx, Expr::new_binop(tcx, op, lhs_ty, rhs_ty, lhs, rhs))
152 }
153 &ExprKind::Unary { op, arg } if check_unop(op) => {
154 let arg_ty = body.exprs[arg].ty;
155 let arg = recurse_build(tcx, body, arg, root_span)?;
156 ty::Const::new_expr(tcx, Expr::new_unop(tcx, op, arg_ty, arg))
157 }
158 ExprKind::Block { block } => {
166 if let thir::Block { stmts: box [], expr: Some(e), .. } = &body.blocks[*block] {
167 recurse_build(tcx, body, *e, root_span)?
168 } else {
169 maybe_supported_error(GenericConstantTooComplexSub::BlockNotSupported(node.span))?
170 }
171 }
172 &ExprKind::Use { source } => {
176 let value_ty = body.exprs[source].ty;
177 let value = recurse_build(tcx, body, source, root_span)?;
178 ty::Const::new_expr(tcx, Expr::new_cast(tcx, CastKind::Use, value_ty, value, node.ty))
179 }
180 &ExprKind::Cast { source } => {
181 let value_ty = body.exprs[source].ty;
182 let value = recurse_build(tcx, body, source, root_span)?;
183 ty::Const::new_expr(tcx, Expr::new_cast(tcx, CastKind::As, value_ty, value, node.ty))
184 }
185 ExprKind::Borrow { arg, .. } => {
186 let arg_node = &body.exprs[*arg];
187
188 if let ExprKind::Deref { arg } = arg_node.kind {
192 recurse_build(tcx, body, arg, root_span)?
193 } else {
194 maybe_supported_error(GenericConstantTooComplexSub::BorrowNotSupported(node.span))?
195 }
196 }
197 ExprKind::RawBorrow { .. } | ExprKind::Deref { .. } => maybe_supported_error(
199 GenericConstantTooComplexSub::AddressAndDerefNotSupported(node.span),
200 )?,
201 ExprKind::Repeat { .. } | ExprKind::Array { .. } => {
202 maybe_supported_error(GenericConstantTooComplexSub::ArrayNotSupported(node.span))?
203 }
204 ExprKind::NeverToAny { .. } => {
205 maybe_supported_error(GenericConstantTooComplexSub::NeverToAnyNotSupported(node.span))?
206 }
207 ExprKind::Tuple { .. } => {
208 maybe_supported_error(GenericConstantTooComplexSub::TupleNotSupported(node.span))?
209 }
210 ExprKind::Index { .. } => {
211 maybe_supported_error(GenericConstantTooComplexSub::IndexNotSupported(node.span))?
212 }
213 ExprKind::Field { .. } => {
214 maybe_supported_error(GenericConstantTooComplexSub::FieldNotSupported(node.span))?
215 }
216 ExprKind::ConstBlock { .. } => {
217 maybe_supported_error(GenericConstantTooComplexSub::ConstBlockNotSupported(node.span))?
218 }
219 ExprKind::Adt(_) => {
220 maybe_supported_error(GenericConstantTooComplexSub::AdtNotSupported(node.span))?
221 }
222 ExprKind::PointerCoercion { .. } => {
224 error(GenericConstantTooComplexSub::PointerNotSupported(node.span))?
225 }
226 ExprKind::Yield { .. } => {
227 error(GenericConstantTooComplexSub::YieldNotSupported(node.span))?
228 }
229 ExprKind::Continue { .. } | ExprKind::Break { .. } | ExprKind::Loop { .. } => {
230 error(GenericConstantTooComplexSub::LoopNotSupported(node.span))?
231 }
232 ExprKind::Box { .. } => error(GenericConstantTooComplexSub::BoxNotSupported(node.span))?,
233 ExprKind::ByUse { .. } => {
234 error(GenericConstantTooComplexSub::ByUseNotSupported(node.span))?
235 }
236 ExprKind::Unary { .. } => unreachable!(),
237 ExprKind::Binary { .. } => {
239 error(GenericConstantTooComplexSub::BinaryNotSupported(node.span))?
240 }
241 ExprKind::LogicalOp { .. } => {
242 error(GenericConstantTooComplexSub::LogicalOpNotSupported(node.span))?
243 }
244 ExprKind::Assign { .. } | ExprKind::AssignOp { .. } => {
245 error(GenericConstantTooComplexSub::AssignNotSupported(node.span))?
246 }
247 ExprKind::Closure { .. } | ExprKind::Return { .. } | ExprKind::Become { .. } => {
249 error(GenericConstantTooComplexSub::ClosureAndReturnNotSupported(node.span))?
250 }
251 ExprKind::Match { .. } | ExprKind::If { .. } | ExprKind::Let { .. } => {
253 error(GenericConstantTooComplexSub::ControlFlowNotSupported(node.span))?
254 }
255 ExprKind::InlineAsm { .. } => {
256 error(GenericConstantTooComplexSub::InlineAsmNotSupported(node.span))?
257 }
258
259 ExprKind::VarRef { .. }
261 | ExprKind::UpvarRef { .. }
262 | ExprKind::StaticRef { .. }
263 | ExprKind::OffsetOf { .. }
264 | ExprKind::ThreadLocalRef(_) => {
265 error(GenericConstantTooComplexSub::OperationNotSupported(node.span))?
266 }
267 })
268}
269
270struct IsThirPolymorphic<'a, 'tcx> {
271 is_poly: bool,
272 thir: &'a thir::Thir<'tcx>,
273}
274
275fn error(
276 tcx: TyCtxt<'_>,
277 sub: GenericConstantTooComplexSub,
278 root_span: Span,
279) -> Result<!, ErrorGuaranteed> {
280 let reported = tcx.dcx().emit_err(GenericConstantTooComplex {
281 span: root_span,
282 maybe_supported: false,
283 sub,
284 });
285
286 Err(reported)
287}
288
289fn maybe_supported_error(
290 tcx: TyCtxt<'_>,
291 sub: GenericConstantTooComplexSub,
292 root_span: Span,
293) -> Result<!, ErrorGuaranteed> {
294 let reported = tcx.dcx().emit_err(GenericConstantTooComplex {
295 span: root_span,
296 maybe_supported: true,
297 sub,
298 });
299
300 Err(reported)
301}
302
303impl<'a, 'tcx> IsThirPolymorphic<'a, 'tcx> {
304 fn expr_is_poly(&mut self, expr: &thir::Expr<'tcx>) -> bool {
305 if expr.ty.has_non_region_param() {
306 return true;
307 }
308
309 match expr.kind {
310 thir::ExprKind::NamedConst { args, .. } | thir::ExprKind::ConstBlock { args, .. } => {
311 args.has_non_region_param()
312 }
313 thir::ExprKind::ConstParam { .. } => true,
314 thir::ExprKind::Repeat { value, count } => {
315 self.visit_expr(&self.thir()[value]);
316 count.has_non_region_param()
317 }
318 thir::ExprKind::Scope { .. }
319 | thir::ExprKind::Box { .. }
320 | thir::ExprKind::If { .. }
321 | thir::ExprKind::Call { .. }
322 | thir::ExprKind::ByUse { .. }
323 | thir::ExprKind::Deref { .. }
324 | thir::ExprKind::Binary { .. }
325 | thir::ExprKind::LogicalOp { .. }
326 | thir::ExprKind::Unary { .. }
327 | thir::ExprKind::Cast { .. }
328 | thir::ExprKind::Use { .. }
329 | thir::ExprKind::NeverToAny { .. }
330 | thir::ExprKind::PointerCoercion { .. }
331 | thir::ExprKind::Loop { .. }
332 | thir::ExprKind::Let { .. }
333 | thir::ExprKind::Match { .. }
334 | thir::ExprKind::Block { .. }
335 | thir::ExprKind::Assign { .. }
336 | thir::ExprKind::AssignOp { .. }
337 | thir::ExprKind::Field { .. }
338 | thir::ExprKind::Index { .. }
339 | thir::ExprKind::VarRef { .. }
340 | thir::ExprKind::UpvarRef { .. }
341 | thir::ExprKind::Borrow { .. }
342 | thir::ExprKind::RawBorrow { .. }
343 | thir::ExprKind::Break { .. }
344 | thir::ExprKind::Continue { .. }
345 | thir::ExprKind::Return { .. }
346 | thir::ExprKind::Become { .. }
347 | thir::ExprKind::Array { .. }
348 | thir::ExprKind::Tuple { .. }
349 | thir::ExprKind::Adt(_)
350 | thir::ExprKind::PlaceTypeAscription { .. }
351 | thir::ExprKind::ValueTypeAscription { .. }
352 | thir::ExprKind::PlaceUnwrapUnsafeBinder { .. }
353 | thir::ExprKind::ValueUnwrapUnsafeBinder { .. }
354 | thir::ExprKind::WrapUnsafeBinder { .. }
355 | thir::ExprKind::Closure(_)
356 | thir::ExprKind::Literal { .. }
357 | thir::ExprKind::NonHirLiteral { .. }
358 | thir::ExprKind::ZstLiteral { .. }
359 | thir::ExprKind::StaticRef { .. }
360 | thir::ExprKind::InlineAsm(_)
361 | thir::ExprKind::OffsetOf { .. }
362 | thir::ExprKind::ThreadLocalRef(_)
363 | thir::ExprKind::Yield { .. } => false,
364 }
365 }
366 fn pat_is_poly(&mut self, pat: &thir::Pat<'tcx>) -> bool {
367 if pat.ty.has_non_region_param() {
368 return true;
369 }
370
371 match pat.kind {
372 thir::PatKind::Constant { value } => value.has_non_region_param(),
373 thir::PatKind::Range(ref range) => {
374 let &thir::PatRange { lo, hi, .. } = range.as_ref();
375 lo.has_non_region_param() || hi.has_non_region_param()
376 }
377 _ => false,
378 }
379 }
380}
381
382impl<'a, 'tcx> visit::Visitor<'a, 'tcx> for IsThirPolymorphic<'a, 'tcx> {
383 fn thir(&self) -> &'a thir::Thir<'tcx> {
384 self.thir
385 }
386
387 #[instrument(skip(self), level = "debug")]
388 fn visit_expr(&mut self, expr: &'a thir::Expr<'tcx>) {
389 self.is_poly |= self.expr_is_poly(expr);
390 if !self.is_poly {
391 visit::walk_expr(self, expr)
392 }
393 }
394
395 #[instrument(skip(self), level = "debug")]
396 fn visit_pat(&mut self, pat: &'a thir::Pat<'tcx>) {
397 self.is_poly |= self.pat_is_poly(pat);
398 if !self.is_poly {
399 visit::walk_pat(self, pat);
400 }
401 }
402}
403
404fn thir_abstract_const<'tcx>(
406 tcx: TyCtxt<'tcx>,
407 def: LocalDefId,
408) -> Result<Option<ty::EarlyBinder<'tcx, ty::Const<'tcx>>>, ErrorGuaranteed> {
409 if !tcx.features().generic_const_exprs() {
410 return Ok(None);
411 }
412
413 match tcx.def_kind(def) {
414 DefKind::AnonConst | DefKind::InlineConst => (),
420 _ => return Ok(None),
421 }
422
423 let body = tcx.thir_body(def)?;
424 let (body, body_id) = (&*body.0.borrow(), body.1);
425
426 let mut is_poly_vis = IsThirPolymorphic { is_poly: false, thir: body };
427 visit::walk_expr(&mut is_poly_vis, &body[body_id]);
428 if !is_poly_vis.is_poly {
429 return Ok(None);
430 }
431
432 let root_span = body.exprs[body_id].span;
433
434 Ok(Some(ty::EarlyBinder::bind(recurse_build(tcx, body, body_id, root_span)?)))
435}
436
437pub(crate) fn provide(providers: &mut Providers) {
438 *providers = Providers { destructure_const, thir_abstract_const, ..*providers };
439}