Skip to content

Commit a39525f

Browse files
committed
Use an aggregate equality comparison for constant array/slice patterns when possible
When every element in an array or slice pattern is a constant and there is no `..` subpattern, the match builder now emits a single call to `PartialEq::eq` instead of comparing each element one by one. This drastically reduces the number of MIR basic blocks for large constant-array matches – e.g. a 64-element `[u8; 64]` match previously generated 64 separate comparison blocks and now generates just one `PartialEq::eq` call that LLVM can lower to a `memcmp()`. The optimisation is gated on having at least two constant elements. Single-element arrays still use a plain scalar comparison. Example:

```rust
const FOO: [u8; 64] = *b"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdef";

pub fn foo(x: &[u8; 64]) -> bool {
    // Before: 64 basic blocks, one per byte.
    // After: a single `PartialEq::eq()` call.
    matches!(x, &FOO)
}
```
1 parent 14196db commit a39525f

4 files changed

Lines changed: 163 additions & 41 deletions

File tree

compiler/rustc_mir_build/src/builder/matches/buckets.rs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -323,6 +323,10 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
323323
value: case_val,
324324
kind: PatConstKind::Float | PatConstKind::Other,
325325
},
326+
)
327+
| (
328+
TestKind::AggregateEq { value: test_val, .. },
329+
TestableCase::Constant { value: case_val, kind: PatConstKind::Aggregate },
326330
) => {
327331
if test_val == case_val {
328332
fully_matched = true;
@@ -353,6 +357,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
353357
| TestKind::Range { .. }
354358
| TestKind::StringEq { .. }
355359
| TestKind::ScalarEq { .. }
360+
| TestKind::AggregateEq { .. }
356361
| TestKind::Deref { .. },
357362
_,
358363
) => {

compiler/rustc_mir_build/src/builder/matches/match_pair.rs

Lines changed: 101 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,41 @@ use rustc_abi::FieldIdx;
44
use rustc_middle::mir::*;
55
use rustc_middle::span_bug;
66
use rustc_middle::thir::*;
7-
use rustc_middle::ty::{self, Ty, TypeVisitableExt};
7+
use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitableExt};
88

99
use crate::builder::Builder;
1010
use crate::builder::expr::as_place::{PlaceBase, PlaceBuilder};
1111
use crate::builder::matches::{
1212
FlatPat, MatchPairTree, PatConstKind, PatternExtraData, SliceLenOp, TestableCase,
1313
};
1414

15+
/// Checks whether every pattern in `elements` is a `PatKind::Constant` and,
16+
/// if so, reconstructs a single aggregate `ty::Value` that represents the whole
17+
/// array or slice. Returns `None` when any element is not a constant or the
18+
/// sequence is too short to benefit from an aggregate comparison.
19+
fn try_reconstruct_aggregate_constant<'tcx>(
20+
tcx: TyCtxt<'tcx>,
21+
aggregate_ty: Ty<'tcx>,
22+
elements: &[Pat<'tcx>],
23+
) -> Option<ty::Value<'tcx>> {
24+
// A single element (or empty array) is not worth an aggregate comparison.
25+
if elements.len() <= 1 {
26+
return None;
27+
}
28+
let branches = elements
29+
.iter()
30+
.map(|pat| {
31+
if let PatKind::Constant { value } = pat.kind {
32+
Some(ty::Const::new_value(tcx, value.valtree, value.ty))
33+
} else {
34+
None
35+
}
36+
})
37+
.collect::<Option<Vec<_>>>()?;
38+
let valtree = ty::ValTree::from_branches(tcx, branches);
39+
Some(ty::Value { ty: aggregate_ty, valtree })
40+
}
41+
1542
impl<'a, 'tcx> Builder<'a, 'tcx> {
1643
/// Builds and pushes [`MatchPairTree`] subtrees, one for each pattern in
1744
/// `subpatterns`, representing the fields of a [`PatKind::Variant`] or
@@ -239,15 +266,30 @@ impl<'tcx> MatchPairTree<'tcx> {
239266
_ => None,
240267
};
241268
if let Some(array_len) = array_len {
242-
cx.prefix_slice_suffix(
243-
&mut subpairs,
244-
extra_data,
245-
&place_builder,
246-
Some(array_len),
247-
prefix,
248-
slice,
249-
suffix,
250-
);
269+
// When all elements are constants and there is no `..`
270+
// subpattern, compare the whole array at once via
271+
// `PartialEq::eq` rather than element by element.
272+
if slice.is_none()
273+
&& suffix.is_empty()
274+
&& let Some(aggregate_value) =
275+
try_reconstruct_aggregate_constant(cx.tcx, pattern.ty, prefix)
276+
{
277+
Some(TestableCase::Constant {
278+
value: aggregate_value,
279+
kind: PatConstKind::Aggregate,
280+
})
281+
} else {
282+
cx.prefix_slice_suffix(
283+
&mut subpairs,
284+
extra_data,
285+
&place_builder,
286+
Some(array_len),
287+
prefix,
288+
slice,
289+
suffix,
290+
);
291+
None
292+
}
251293
} else {
252294
// If the array length couldn't be determined, ignore the
253295
// subpatterns and delayed-assert that compilation will fail.
@@ -258,37 +300,60 @@ impl<'tcx> MatchPairTree<'tcx> {
258300
pattern.ty
259301
),
260302
);
303+
None
261304
}
262-
263-
None
264305
}
265306
PatKind::Slice { ref prefix, ref slice, ref suffix } => {
266-
cx.prefix_slice_suffix(
267-
&mut subpairs,
268-
extra_data,
269-
&place_builder,
270-
None,
271-
prefix,
272-
slice,
273-
suffix,
274-
);
275-
276-
if prefix.is_empty() && slice.is_some() && suffix.is_empty() {
277-
// This pattern is shaped like `[..]`. It can match a slice
278-
// of any length, so no length test is needed.
279-
None
280-
} else {
281-
// Any other shape of slice pattern requires a length test.
282-
// Slice patterns with a `..` subpattern require a minimum
283-
// length; those without `..` require an exact length.
284-
Some(TestableCase::Slice {
285-
len: u64::try_from(prefix.len() + suffix.len()).unwrap(),
286-
op: if slice.is_some() {
287-
SliceLenOp::GreaterOrEqual
288-
} else {
289-
SliceLenOp::Equal
307+
// When there is no `..`, all elements are constants, and
308+
// there are at least two of them, collapse the individual
309+
// element subpairs into a single aggregate comparison that
310+
// is performed after the length check.
311+
if slice.is_none()
312+
&& suffix.is_empty()
313+
&& let Some(aggregate_value) =
314+
try_reconstruct_aggregate_constant(cx.tcx, pattern.ty, prefix)
315+
{
316+
subpairs.push(MatchPairTree {
317+
place,
318+
testable_case: TestableCase::Constant {
319+
value: aggregate_value,
320+
kind: PatConstKind::Aggregate,
290321
},
322+
subpairs: Vec::new(),
323+
pattern_span: pattern.span,
324+
});
325+
Some(TestableCase::Slice {
326+
len: u64::try_from(prefix.len()).unwrap(),
327+
op: SliceLenOp::Equal,
291328
})
329+
} else {
330+
cx.prefix_slice_suffix(
331+
&mut subpairs,
332+
extra_data,
333+
&place_builder,
334+
None,
335+
prefix,
336+
slice,
337+
suffix,
338+
);
339+
340+
if prefix.is_empty() && slice.is_some() && suffix.is_empty() {
341+
// This pattern is shaped like `[..]`. It can match
342+
// a slice of any length, so no length test is needed.
343+
None
344+
} else {
345+
// Any other shape of slice pattern requires a length test.
346+
// Slice patterns with a `..` subpattern require a minimum
347+
// length; those without `..` require an exact length.
348+
Some(TestableCase::Slice {
349+
len: u64::try_from(prefix.len() + suffix.len()).unwrap(),
350+
op: if slice.is_some() {
351+
SliceLenOp::GreaterOrEqual
352+
} else {
353+
SliceLenOp::Equal
354+
},
355+
})
356+
}
292357
}
293358
}
294359

compiler/rustc_mir_build/src/builder/matches/mod.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1266,6 +1266,10 @@ enum PatConstKind {
12661266
Float,
12671267
/// Constant string values, tested via string equality.
12681268
String,
1269+
/// Constant array or slice values where every element is a constant.
1270+
/// Tested by calling `PartialEq::eq` on the whole aggregate at once,
1271+
/// rather than comparing element by element.
1272+
Aggregate,
12691273
/// Any other constant-pattern is usually tested via some kind of equality
12701274
/// check. Types that might be encountered here include:
12711275
/// - raw pointers derived from integer values
@@ -1351,6 +1355,10 @@ enum TestKind<'tcx> {
13511355
/// Tests the place against a constant using scalar equality.
13521356
ScalarEq { value: ty::Value<'tcx> },
13531357

1358+
/// Tests the place against a constant array or slice using `PartialEq::eq`,
1359+
/// comparing the whole aggregate at once rather than element by element.
1360+
AggregateEq { value: ty::Value<'tcx> },
1361+
13541362
/// Test whether the value falls within an inclusive or exclusive range.
13551363
Range(Arc<PatRange<'tcx>>),
13561364

compiler/rustc_mir_build/src/builder/matches/test.rs

Lines changed: 49 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
4040
TestableCase::Constant { value, kind: PatConstKind::String } => {
4141
TestKind::StringEq { value }
4242
}
43+
TestableCase::Constant { value, kind: PatConstKind::Aggregate } => {
44+
TestKind::AggregateEq { value }
45+
}
4346
TestableCase::Constant { value, kind: PatConstKind::Float | PatConstKind::Other } => {
4447
TestKind::ScalarEq { value }
4548
}
@@ -168,16 +171,54 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
168171
// Compare two strings using `<str as std::cmp::PartialEq>::eq`.
169172
// (Interestingly this means that exhaustiveness analysis relies, for soundness,
170173
// on the `PartialEq` impl for `str` to be correct!)
171-
self.string_compare(
174+
self.non_scalar_compare(
172175
block,
173176
success_block,
174177
fail_block,
175178
source_info,
179+
tcx.types.str_,
176180
expected_value_operand,
177181
Operand::Copy(actual_value_ref_place),
178182
);
179183
}
180184

185+
TestKind::AggregateEq { value } => {
186+
let tcx = self.tcx;
187+
let success_block = target_block(TestBranch::Success);
188+
let fail_block = target_block(TestBranch::Failure);
189+
190+
let aggregate_ty = value.ty;
191+
let ref_ty = Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, aggregate_ty);
192+
193+
// The constant has type `[T; N]` (or `[T]`), but calling
194+
// `PartialEq::eq` requires `&[T; N]` (or `&[T]`) operands.
195+
// Valtree representations are the same with or without the
196+
// reference wrapper, so we can reinterpret by replacing the type.
197+
let expected_value = ty::Value { ty: ref_ty, valtree: value.valtree };
198+
let expected_operand =
199+
self.literal_operand(test.span, Const::from_ty_value(tcx, expected_value));
200+
201+
// Create a reference to the scrutinee place.
202+
let actual_ref_place = self.temp(ref_ty, test.span);
203+
self.cfg.push_assign(
204+
block,
205+
self.source_info(test.span),
206+
actual_ref_place,
207+
Rvalue::Ref(tcx.lifetimes.re_erased, BorrowKind::Shared, place),
208+
);
209+
210+
// Compare using `<T as PartialEq>::eq` where `T` is the array or slice type.
211+
self.non_scalar_compare(
212+
block,
213+
success_block,
214+
fail_block,
215+
source_info,
216+
aggregate_ty,
217+
expected_operand,
218+
Operand::Copy(actual_ref_place),
219+
);
220+
}
221+
181222
TestKind::ScalarEq { value } => {
182223
let tcx = self.tcx;
183224
let success_block = target_block(TestBranch::Success);
@@ -404,19 +445,22 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
404445
);
405446
}
406447

407-
/// Compare two values of type `&str` using `<str as std::cmp::PartialEq>::eq`.
408-
fn string_compare(
448+
/// Compare two reference values using `<T as PartialEq>::eq`.
449+
///
450+
/// `compared_ty` is the *inner* type (e.g. `str`, `[u8; 64]`);
451+
/// `expect` and `val` must already be references to that type.
452+
fn non_scalar_compare(
409453
&mut self,
410454
block: BasicBlock,
411455
success_block: BasicBlock,
412456
fail_block: BasicBlock,
413457
source_info: SourceInfo,
458+
compared_ty: Ty<'tcx>,
414459
expect: Operand<'tcx>,
415460
val: Operand<'tcx>,
416461
) {
417-
let str_ty = self.tcx.types.str_;
418462
let eq_def_id = self.tcx.require_lang_item(LangItem::PartialEq, source_info.span);
419-
let method = trait_method(self.tcx, eq_def_id, sym::eq, [str_ty, str_ty]);
463+
let method = trait_method(self.tcx, eq_def_id, sym::eq, [compared_ty, compared_ty]);
420464

421465
let bool_ty = self.tcx.types.bool;
422466
let eq_result = self.temp(bool_ty, source_info.span);

0 commit comments

Comments
 (0)