Skip to content

Commit 5a6dba6

Browse files
committed
Auto merge of #155216 - jakubadamw:issue-110870-103073, r=<try>
match: Use an aggregate equality comparison for constant array/slice patterns when possible
2 parents 14196db + 0cd52bb commit 5a6dba6

20 files changed

Lines changed: 1723 additions & 41 deletions

compiler/rustc_mir_build/src/builder/matches/buckets.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,10 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
323323
value: case_val,
324324
kind: PatConstKind::Float | PatConstKind::Other,
325325
},
326+
)
327+
| (
328+
TestKind::AggregateEq { value: test_val, .. },
329+
TestableCase::Constant { value: case_val, kind: PatConstKind::Aggregate },
326330
) => {
327331
if test_val == case_val {
328332
fully_matched = true;
@@ -353,6 +357,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
353357
| TestKind::Range { .. }
354358
| TestKind::StringEq { .. }
355359
| TestKind::ScalarEq { .. }
360+
| TestKind::AggregateEq { .. }
356361
| TestKind::Deref { .. },
357362
_,
358363
) => {

compiler/rustc_mir_build/src/builder/matches/match_pair.rs

Lines changed: 117 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,15 +4,56 @@ use rustc_abi::FieldIdx;
44
use rustc_middle::mir::*;
55
use rustc_middle::span_bug;
66
use rustc_middle::thir::*;
7-
use rustc_middle::ty::{self, Ty, TypeVisitableExt};
7+
use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitableExt};
8+
use rustc_span::sym;
89

910
use crate::builder::Builder;
1011
use crate::builder::expr::as_place::{PlaceBase, PlaceBuilder};
1112
use crate::builder::matches::{
1213
FlatPat, MatchPairTree, PatConstKind, PatternExtraData, SliceLenOp, TestableCase,
1314
};
1415

16+
/// Checks whether every pattern in `elements` is a `PatKind::Constant` and,
17+
/// if so, reconstructs a single aggregate `ty::Value` that represents the whole
18+
/// array or slice. Returns `None` when any element is not a constant or the
19+
/// sequence is too short to benefit from an aggregate comparison.
20+
fn try_reconstruct_aggregate_constant<'tcx>(
21+
tcx: TyCtxt<'tcx>,
22+
aggregate_ty: Ty<'tcx>,
23+
elements: &[Pat<'tcx>],
24+
) -> Option<ty::Value<'tcx>> {
25+
// A single element (or empty array) is not worth an aggregate comparison.
26+
if elements.len() <= 1 {
27+
return None;
28+
}
29+
let branches = elements
30+
.iter()
31+
.map(|pat| {
32+
if let PatKind::Constant { value } = pat.kind {
33+
Some(ty::Const::new_value(tcx, value.valtree, value.ty))
34+
} else {
35+
None
36+
}
37+
})
38+
.collect::<Option<Vec<_>>>()?;
39+
let valtree = ty::ValTree::from_branches(tcx, branches);
40+
Some(ty::Value { ty: aggregate_ty, valtree })
41+
}
42+
1543
impl<'a, 'tcx> Builder<'a, 'tcx> {
44+
/// Check if we can use aggregate `PartialEq::eq` comparisons for constant array/slice patterns.
45+
/// This is not possible in const contexts unless `#![feature(const_cmp, const_trait_impl)]` are enabled,
46+
/// because`PartialEq` is not const-stable.
47+
fn can_use_aggregate_eq(&self) -> bool {
48+
let const_partial_eq_enabled = {
49+
let features = self.tcx.features();
50+
features.enabled(sym::const_trait_impl) && features.enabled(sym::const_cmp)
51+
};
52+
let in_const_context = self.tcx.is_const_fn(self.def_id.to_def_id())
53+
|| !self.tcx.hir_body_owner_kind(self.def_id).is_fn_or_closure();
54+
!in_const_context || const_partial_eq_enabled
55+
}
56+
1657
/// Builds and pushes [`MatchPairTree`] subtrees, one for each pattern in
1758
/// `subpatterns`, representing the fields of a [`PatKind::Variant`] or
1859
/// [`PatKind::Leaf`].
@@ -239,15 +280,31 @@ impl<'tcx> MatchPairTree<'tcx> {
239280
_ => None,
240281
};
241282
if let Some(array_len) = array_len {
242-
cx.prefix_slice_suffix(
243-
&mut subpairs,
244-
extra_data,
245-
&place_builder,
246-
Some(array_len),
247-
prefix,
248-
slice,
249-
suffix,
250-
);
283+
// When all elements are constants and there is no `..`
284+
// subpattern, compare the whole array at once via
285+
// `PartialEq::eq` rather than element by element.
286+
if slice.is_none()
287+
&& suffix.is_empty()
288+
&& cx.can_use_aggregate_eq()
289+
&& let Some(aggregate_value) =
290+
try_reconstruct_aggregate_constant(cx.tcx, pattern.ty, prefix)
291+
{
292+
Some(TestableCase::Constant {
293+
value: aggregate_value,
294+
kind: PatConstKind::Aggregate,
295+
})
296+
} else {
297+
cx.prefix_slice_suffix(
298+
&mut subpairs,
299+
extra_data,
300+
&place_builder,
301+
Some(array_len),
302+
prefix,
303+
slice,
304+
suffix,
305+
);
306+
None
307+
}
251308
} else {
252309
// If the array length couldn't be determined, ignore the
253310
// subpatterns and delayed-assert that compilation will fail.
@@ -258,37 +315,61 @@ impl<'tcx> MatchPairTree<'tcx> {
258315
pattern.ty
259316
),
260317
);
318+
None
261319
}
262-
263-
None
264320
}
265321
PatKind::Slice { ref prefix, ref slice, ref suffix } => {
266-
cx.prefix_slice_suffix(
267-
&mut subpairs,
268-
extra_data,
269-
&place_builder,
270-
None,
271-
prefix,
272-
slice,
273-
suffix,
274-
);
275-
276-
if prefix.is_empty() && slice.is_some() && suffix.is_empty() {
277-
// This pattern is shaped like `[..]`. It can match a slice
278-
// of any length, so no length test is needed.
279-
None
280-
} else {
281-
// Any other shape of slice pattern requires a length test.
282-
// Slice patterns with a `..` subpattern require a minimum
283-
// length; those without `..` require an exact length.
284-
Some(TestableCase::Slice {
285-
len: u64::try_from(prefix.len() + suffix.len()).unwrap(),
286-
op: if slice.is_some() {
287-
SliceLenOp::GreaterOrEqual
288-
} else {
289-
SliceLenOp::Equal
322+
// When there is no `..`, all elements are constants, and
323+
// there are at least two of them, collapse the individual
324+
// element subpairs into a single aggregate comparison that
325+
// is performed after the length check.
326+
if slice.is_none()
327+
&& suffix.is_empty()
328+
&& cx.can_use_aggregate_eq()
329+
&& let Some(aggregate_value) =
330+
try_reconstruct_aggregate_constant(cx.tcx, pattern.ty, prefix)
331+
{
332+
subpairs.push(MatchPairTree {
333+
place,
334+
testable_case: TestableCase::Constant {
335+
value: aggregate_value,
336+
kind: PatConstKind::Aggregate,
290337
},
338+
subpairs: Vec::new(),
339+
pattern_span: pattern.span,
340+
});
341+
Some(TestableCase::Slice {
342+
len: u64::try_from(prefix.len()).unwrap(),
343+
op: SliceLenOp::Equal,
291344
})
345+
} else {
346+
cx.prefix_slice_suffix(
347+
&mut subpairs,
348+
extra_data,
349+
&place_builder,
350+
None,
351+
prefix,
352+
slice,
353+
suffix,
354+
);
355+
356+
if prefix.is_empty() && slice.is_some() && suffix.is_empty() {
357+
// This pattern is shaped like `[..]`. It can match
358+
// a slice of any length, so no length test is needed.
359+
None
360+
} else {
361+
// Any other shape of slice pattern requires a length test.
362+
// Slice patterns with a `..` subpattern require a minimum
363+
// length; those without `..` require an exact length.
364+
Some(TestableCase::Slice {
365+
len: u64::try_from(prefix.len() + suffix.len()).unwrap(),
366+
op: if slice.is_some() {
367+
SliceLenOp::GreaterOrEqual
368+
} else {
369+
SliceLenOp::Equal
370+
},
371+
})
372+
}
292373
}
293374
}
294375

compiler/rustc_mir_build/src/builder/matches/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1266,6 +1266,10 @@ enum PatConstKind {
12661266
Float,
12671267
/// Constant string values, tested via string equality.
12681268
String,
1269+
/// Constant array or slice values where every element is a constant.
1270+
/// Tested by calling `PartialEq::eq` on the whole aggregate at once,
1271+
/// rather than comparing element by element.
1272+
Aggregate,
12691273
/// Any other constant-pattern is usually tested via some kind of equality
12701274
/// check. Types that might be encountered here include:
12711275
/// - raw pointers derived from integer values
@@ -1351,6 +1355,10 @@ enum TestKind<'tcx> {
13511355
/// Tests the place against a constant using scalar equality.
13521356
ScalarEq { value: ty::Value<'tcx> },
13531357

1358+
/// Tests the place against a constant array or slice using `PartialEq::eq`,
1359+
/// comparing the whole aggregate at once rather than element by element.
1360+
AggregateEq { value: ty::Value<'tcx> },
1361+
13541362
/// Test whether the value falls within an inclusive or exclusive range.
13551363
Range(Arc<PatRange<'tcx>>),
13561364

compiler/rustc_mir_build/src/builder/matches/test.rs

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
4040
TestableCase::Constant { value, kind: PatConstKind::String } => {
4141
TestKind::StringEq { value }
4242
}
43+
TestableCase::Constant { value, kind: PatConstKind::Aggregate } => {
44+
TestKind::AggregateEq { value }
45+
}
4346
TestableCase::Constant { value, kind: PatConstKind::Float | PatConstKind::Other } => {
4447
TestKind::ScalarEq { value }
4548
}
@@ -168,16 +171,54 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
168171
// Compare two strings using `<str as std::cmp::PartialEq>::eq`.
169172
// (Interestingly this means that exhaustiveness analysis relies, for soundness,
170173
// on the `PartialEq` impl for `str` to be correct!)
171-
self.string_compare(
174+
self.non_scalar_compare(
172175
block,
173176
success_block,
174177
fail_block,
175178
source_info,
179+
tcx.types.str_,
176180
expected_value_operand,
177181
Operand::Copy(actual_value_ref_place),
178182
);
179183
}
180184

185+
TestKind::AggregateEq { value } => {
186+
let tcx = self.tcx;
187+
let success_block = target_block(TestBranch::Success);
188+
let fail_block = target_block(TestBranch::Failure);
189+
190+
let aggregate_ty = value.ty;
191+
let ref_ty = Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, aggregate_ty);
192+
193+
// The constant has type `[T; N]` (or `[T]`), but calling
194+
// `PartialEq::eq` requires `&[T; N]` (or `&[T]`) operands.
195+
// Valtree representations are the same with or without the
196+
// reference wrapper, so we can reinterpret by replacing the type.
197+
let expected_value = ty::Value { ty: ref_ty, valtree: value.valtree };
198+
let expected_operand =
199+
self.literal_operand(test.span, Const::from_ty_value(tcx, expected_value));
200+
201+
// Create a reference to the scrutinee place.
202+
let actual_ref_place = self.temp(ref_ty, test.span);
203+
self.cfg.push_assign(
204+
block,
205+
self.source_info(test.span),
206+
actual_ref_place,
207+
Rvalue::Ref(tcx.lifetimes.re_erased, BorrowKind::Shared, place),
208+
);
209+
210+
// Compare using `<T as PartialEq>::eq` where `T` is the array or slice type.
211+
self.non_scalar_compare(
212+
block,
213+
success_block,
214+
fail_block,
215+
source_info,
216+
aggregate_ty,
217+
expected_operand,
218+
Operand::Copy(actual_ref_place),
219+
);
220+
}
221+
181222
TestKind::ScalarEq { value } => {
182223
let tcx = self.tcx;
183224
let success_block = target_block(TestBranch::Success);
@@ -404,19 +445,22 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
404445
);
405446
}
406447

407-
/// Compare two values of type `&str` using `<str as std::cmp::PartialEq>::eq`.
408-
fn string_compare(
448+
/// Compare two reference values using `<T as PartialEq>::eq`.
449+
///
450+
/// `compared_ty` is the *inner* type (e.g. `str`, `[u8; 64]`);
451+
/// `expect` and `val` must already be references to that type.
452+
fn non_scalar_compare(
409453
&mut self,
410454
block: BasicBlock,
411455
success_block: BasicBlock,
412456
fail_block: BasicBlock,
413457
source_info: SourceInfo,
458+
compared_ty: Ty<'tcx>,
414459
expect: Operand<'tcx>,
415460
val: Operand<'tcx>,
416461
) {
417-
let str_ty = self.tcx.types.str_;
418462
let eq_def_id = self.tcx.require_lang_item(LangItem::PartialEq, source_info.span);
419-
let method = trait_method(self.tcx, eq_def_id, sym::eq, [str_ty, str_ty]);
463+
let method = trait_method(self.tcx, eq_def_id, sym::eq, [compared_ty, compared_ty]);
420464

421465
let bool_ty = self.tcx.types.bool;
422466
let eq_result = self.temp(bool_ty, source_info.span);

compiler/rustc_span/src/symbol.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -649,6 +649,7 @@ symbols! {
649649
const_block_items,
650650
const_c_variadic,
651651
const_closures,
652+
const_cmp,
652653
const_compare_raw_pointers,
653654
const_constructor,
654655
const_continue,
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// MIR for `array_match` after built
2+
3+
fn array_match(_1: [u8; 4]) -> bool {
4+
debug x => _1;
5+
let mut _0: bool;
6+
let mut _2: &[u8; 4];
7+
let mut _3: bool;
8+
scope 1 {
9+
}
10+
11+
bb0: {
12+
PlaceMention(_1);
13+
_2 = &_1;
14+
_3 = <[u8; 4] as PartialEq>::eq(copy _2, const &*b"\x01\x02\x03\x04") -> [return: bb4, unwind: bb8];
15+
}
16+
17+
bb1: {
18+
_0 = const false;
19+
goto -> bb7;
20+
}
21+
22+
bb2: {
23+
falseEdge -> [real: bb6, imaginary: bb1];
24+
}
25+
26+
bb3: {
27+
goto -> bb1;
28+
}
29+
30+
bb4: {
31+
switchInt(move _3) -> [0: bb1, otherwise: bb2];
32+
}
33+
34+
bb5: {
35+
FakeRead(ForMatchedPlace(None), _1);
36+
unreachable;
37+
}
38+
39+
bb6: {
40+
_0 = const true;
41+
goto -> bb7;
42+
}
43+
44+
bb7: {
45+
return;
46+
}
47+
48+
bb8 (cleanup): {
49+
resume;
50+
}
51+
}

0 commit comments

Comments
 (0)