1use rustc_ast::{BinOpKind, BorrowKind, Expr, ExprKind, MetaItem, Mutability, Safety};
2use rustc_expand::base::{Annotatable, ExtCtxt};
3use rustc_span::{Span, sym};
4use thin_vec::thin_vec;
56use crate::deriving::generic::ty::*;
7use crate::deriving::generic::*;
8use crate::deriving::{path_local, path_std};
910/// Expands a `#[derive(PartialEq)]` attribute into an implementation for the
11/// target item.
12pub(crate) fn expand_deriving_partial_eq(
13 cx: &ExtCtxt<'_>,
14 span: Span,
15 mitem: &MetaItem,
16 item: &Annotatable,
17 push: &mut dyn FnMut(Annotatable),
18 is_const: bool,
19) {
20let structural_trait_def = TraitDef {
21span,
22 path: generic::ty::Path::new({
<[_]>::into_vec(::alloc::boxed::box_new([sym::marker,
sym::StructuralPartialEq]))
})path_std!(marker::StructuralPartialEq),
23 skip_path_as_bound: true, // crucial!
24needs_copy_as_bound_if_packed: false,
25 additional_bounds: Vec::new(),
26// We really don't support unions, but that's already checked by the impl generated below;
27 // a second check here would lead to redundant error messages.
28supports_unions: true,
29 methods: Vec::new(),
30 associated_types: Vec::new(),
31 is_const: false,
32 is_staged_api_crate: cx.ecfg.features.staged_api(),
33 safety: Safety::Default,
34 document: true,
35 };
36structural_trait_def.expand(cx, mitem, item, push);
3738// No need to generate `ne`, the default suffices, and not generating it is
39 // faster.
40let methods = <[_]>::into_vec(::alloc::boxed::box_new([MethodDef {
name: sym::eq,
generics: Bounds::empty(),
explicit_self: true,
nonself_args: <[_]>::into_vec(::alloc::boxed::box_new([(self_ref(),
sym::other)])),
ret_ty: Path(generic::ty::Path::new_local(sym::bool)),
attributes: {
let len = [()].len();
let mut vec = ::thin_vec::ThinVec::with_capacity(len);
vec.push(cx.attr_word(sym::inline, span));
vec
},
fieldless_variants_strategy: FieldlessVariantsStrategy::Unify,
combine_substructure: combine_substructure(Box::new(|a, b,
c|
{
BlockOrExpr::new_expr(get_substructure_equality_expr(a, b,
c))
})),
}]))vec![MethodDef {
41 name: sym::eq,
42 generics: Bounds::empty(),
43 explicit_self: true,
44 nonself_args: vec![(self_ref(), sym::other)],
45 ret_ty: Path(path_local!(bool)),
46 attributes: thin_vec![cx.attr_word(sym::inline, span)],
47 fieldless_variants_strategy: FieldlessVariantsStrategy::Unify,
48 combine_substructure: combine_substructure(Box::new(|a, b, c| {
49 BlockOrExpr::new_expr(get_substructure_equality_expr(a, b, c))
50 })),
51 }];
5253let trait_def = TraitDef {
54span,
55 path: generic::ty::Path::new({
<[_]>::into_vec(::alloc::boxed::box_new([sym::cmp, sym::PartialEq]))
})path_std!(cmp::PartialEq),
56 skip_path_as_bound: false,
57 needs_copy_as_bound_if_packed: true,
58 additional_bounds: Vec::new(),
59 supports_unions: false,
60methods,
61 associated_types: Vec::new(),
62is_const,
63 is_staged_api_crate: cx.ecfg.features.staged_api(),
64 safety: Safety::Default,
65 document: true,
66 };
67trait_def.expand(cx, mitem, item, push)
68}
6970/// Generates the equality expression for a struct or enum variant when deriving
71/// `PartialEq`.
72///
73/// This function generates an expression that checks if all fields of a struct
74/// or enum variant are equal.
75/// - Scalar fields are compared first for efficiency, followed by compound
76/// fields.
77/// - If there are no fields, returns `true` (fieldless types are always equal).
78///
79/// Whether a field is considered "scalar" is determined by comparing the symbol
80/// of its type to a set of known scalar type symbols (e.g., `i32`, `u8`, etc).
81/// This check is based on the type's symbol.
82///
83/// ### Example 1
84/// ```
85/// #[derive(PartialEq)]
86/// struct i32;
87///
88/// // Here, `field_2` is of type `i32`, but since it's a user-defined type (not
89/// // the primitive), it will not be treated as scalar. The function will still
90/// // check equality of `field_2` first because the symbol matches `i32`.
91/// #[derive(PartialEq)]
92/// struct Struct {
93/// field_1: &'static str,
94/// field_2: i32,
95/// }
96/// ```
97///
98/// ### Example 2
99/// ```
100/// mod ty {
101/// pub type i32 = i32;
102/// }
103///
104/// // Here, `field_2` is of type `ty::i32`, which is a type alias for `i32`.
105/// // However, the function will not reorder the fields because the symbol for
106/// // `ty::i32` does not match the symbol for the primitive `i32`
107/// // ("ty::i32" != "i32").
108/// #[derive(PartialEq)]
109/// struct Struct {
110/// field_1: &'static str,
111/// field_2: ty::i32,
112/// }
113/// ```
114///
115/// For enums, the discriminant is compared first, then the rest of the fields.
116///
117/// # Panics
118///
119/// If called on static or all-fieldless enums/structs, which should not occur
120/// during derive expansion.
121fn get_substructure_equality_expr(
122 cx: &ExtCtxt<'_>,
123 span: Span,
124 substructure: &Substructure<'_>,
125) -> Box<Expr> {
126use SubstructureFields::*;
127128match substructure.fields {
129EnumMatching(.., fields) | Struct(.., fields) => {
130let combine = move |acc, field| {
131let rhs = get_field_equality_expr(cx, field);
132if let Some(lhs) = acc {
133// Combine the previous comparison with the current field
134 // using logical AND.
135return Some(cx.expr_binary(field.span, BinOpKind::And, lhs, rhs));
136 }
137// Start the chain with the first field's comparison.
138Some(rhs)
139 };
140141// First compare scalar fields, then compound fields, combining all
142 // with logical AND.
143return fields144 .iter()
145 .filter(|field| !field.maybe_scalar)
146 .fold(fields.iter().filter(|field| field.maybe_scalar).fold(None, combine), combine)
147// If there are no fields, treat as always equal.
148.unwrap_or_else(|| cx.expr_bool(span, true));
149 }
150EnumDiscr(disc, match_expr) => {
151let lhs = get_field_equality_expr(cx, disc);
152let Some(match_expr) = match_exprelse {
153return lhs;
154 };
155// Compare the discriminant first (cheaper), then the rest of the
156 // fields.
157return cx.expr_binary(disc.span, BinOpKind::And, lhs, match_expr.clone());
158 }
159StaticEnum(..) => cx.dcx().span_bug(
160span,
161"unexpected static enum encountered during `derive(PartialEq)` expansion",
162 ),
163StaticStruct(..) => cx.dcx().span_bug(
164span,
165"unexpected static struct encountered during `derive(PartialEq)` expansion",
166 ),
167AllFieldlessEnum(..) => cx.dcx().span_bug(
168span,
169"unexpected all-fieldless enum encountered during `derive(PartialEq)` expansion",
170 ),
171 }
172}
173174/// Generates an equality comparison expression for a single struct or enum
175/// field.
176///
177/// This function produces an AST expression that compares the `self` and
178/// `other` values for a field using `==`. It removes any leading references
179/// from both sides for readability. If the field is a block expression, it is
180/// wrapped in parentheses to ensure valid syntax.
181///
182/// # Panics
183///
184/// Panics if there are not exactly two arguments to compare (should be `self`
185/// and `other`).
186fn get_field_equality_expr(cx: &ExtCtxt<'_>, field: &FieldInfo) -> Box<Expr> {
187let [rhs] = &field.other_selflike_exprs[..] else {
188cx.dcx().span_bug(field.span, "not exactly 2 arguments in `derive(PartialEq)`");
189 };
190191cx.expr_binary(
192field.span,
193 BinOpKind::Eq,
194wrap_block_expr(cx, peel_refs(&field.self_expr)),
195wrap_block_expr(cx, peel_refs(rhs)),
196 )
197}
198199/// Removes all leading immutable references from an expression.
200///
201/// This is used to strip away any number of leading `&` from an expression
202/// (e.g., `&&&T` becomes `T`). Only removes immutable references; mutable
203/// references are preserved.
204fn peel_refs(mut expr: &Box<Expr>) -> Box<Expr> {
205while let ExprKind::AddrOf(BorrowKind::Ref, Mutability::Not, inner) = &expr.kind {
206 expr = &inner;
207 }
208expr.clone()
209}
210211/// Wraps a block expression in parentheses to ensure valid AST in macro
212/// expansion output.
213///
214/// If the given expression is a block, it is wrapped in parentheses; otherwise,
215/// it is returned unchanged.
216fn wrap_block_expr(cx: &ExtCtxt<'_>, expr: Box<Expr>) -> Box<Expr> {
217if #[allow(non_exhaustive_omitted_patterns)] match &expr.kind {
ExprKind::Block(..) => true,
_ => false,
}matches!(&expr.kind, ExprKind::Block(..)) {
218return cx.expr_paren(expr.span, expr);
219 }
220expr221}