Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions compiler/rustc_mir_build/src/builder/matches/buckets.rs
Original file line number Diff line number Diff line change
Expand Up @@ -323,6 +323,10 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
value: case_val,
kind: PatConstKind::Float | PatConstKind::Other,
},
)
| (
TestKind::AggregateEq { value: test_val, .. },
TestableCase::Constant { value: case_val, kind: PatConstKind::Aggregate },
) => {
if test_val == case_val {
fully_matched = true;
Expand Down Expand Up @@ -353,6 +357,7 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
| TestKind::Range { .. }
| TestKind::StringEq { .. }
| TestKind::ScalarEq { .. }
| TestKind::AggregateEq { .. }
| TestKind::Deref { .. },
_,
) => {
Expand Down
142 changes: 117 additions & 25 deletions compiler/rustc_mir_build/src/builder/matches/match_pair.rs
Copy link
Copy Markdown
Member

@Nadrieril Nadrieril Apr 22, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The general approach feels unfortunate: we transformed constants into patterns in const_to_pat, and now we try to reverse that transformation. Have we tried keeping the original constant around in the output of const_to_pat and using it at runtime? Tho this has the same issue of const-dependent MIR lowering that @dianne pointed out.

View changes since the review

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we do that, maybe we could also synthesize a constant during THIR building for hand-written array/slice patterns, like this PR currently does in MIR building? That way, we wouldn't end up worse codegen for hand-written array patterns than for const array items used as patterns.

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I dunno, small handwritten arrays behave pretty much like tuples, it could be better codegen sometimes not to make them into constants:

match foo {
  // compiles to three nested `if`s today, would become 8 sequential `if`s if turned into constants
  [false, false, false] => ...,
  [false, false, true] => ...,
  [false, true, false] => ...,
  ...
}

It's admittedly a stretch but I'm tempted to err on the side of respecting user intent especially before the MIR boundary.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh, right. Maybe we should have a test for that in mir-opt/building/match/sort_candidates.rs or such if we start optimizing array/slice comparisons? I think as-is this PR may also turn that into 8 sequential tests, each a TestKind::AggregateEq for a different constant.

Copy link
Copy Markdown
Contributor Author

@jakubadamw jakubadamw Apr 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Nadrieril, as an alternative to being guided solely by user “intent”, would it be sensible to raise the threshold on the number of elements in an array pattern where we would use the aggregate equality? Right now it’s > 2. Perhaps 4 would work better? I suppose a quantitative comparison with benchmarks could be of use here, but sadly I can’t commit to that with my present schedule.

Copy link
Copy Markdown
Member

@Nadrieril Nadrieril Apr 27, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

On further thought, I think this is an unsound optimization, for example:

/// Safety: if the first byte is zero the rest of the slice
/// must be initialized; otherwise all bets are off.
unsafe fn foo(x: *const [u8]) {
    unsafe { match *x {
        [0, 1] => ...,
        _ => ...,
    } }
}

This function is sound if patterns were guaranteed to match left-to-right (this isn't set in stone yet, but likely).

Here if you replace this with if PartialEq::eq(&*x, &[0, 1]) not only are you taking a reference which itself has opsem consequences (this may make some later foreign writes invalid), but also may be reading uninitialized data.

I think this convinces me we shouldn't even attempt to do this. Preserving "user intent" actually means preserving correct semantics here.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe it'd be most appropriate to instead fix whatever changed to make LLVM stop doing this optimization? Looking at #110870 (comment), the match in the example function f in that issue used to get simplified down enough in LLVM that the SLP vectorizer pass could turn the sequence of comparisons into a vector comparison in LLVM IR. I'm not familiar enough with LLVM to say what changed or how to fix it, though.

The other place I could imagine doing something like this would be a MIR transform pass, but that seems iffy. I haven't followed the relevant opsem discussions super closely (other than reading rust-lang/unsafe-code-guidelines#346 just now) but it seems like even turning the safe match on a normal reference in f in #110870 into a single comparison in MIR could introduce UB, judging from testing with CTFE and Miri (without -Zmiri-recursive-validation): those don't detect UB if f is given a reference to partially-initialized array if it short-circuits before reading anything uninit; of course, it's UB if the match becomes a call to a comparison intrinsic like a raw_eq or compare_bytes. Maybe a new comparison intrinsic or two with the right semantics could both work and give us more control over the LLVM IR we emit, but I'm guessing that's extreme overkill.

Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,56 @@ use rustc_abi::FieldIdx;
use rustc_middle::mir::*;
use rustc_middle::span_bug;
use rustc_middle::thir::*;
use rustc_middle::ty::{self, Ty, TypeVisitableExt};
use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitableExt};

use crate::builder::Builder;
use crate::builder::expr::as_place::{PlaceBase, PlaceBuilder};
use crate::builder::matches::{
FlatPat, MatchPairTree, PatConstKind, PatternExtraData, SliceLenOp, TestableCase,
};

/// Below this length, an array or slice pattern is compared element by element
/// rather than as a single aggregate, since the per-element comparisons are
/// unlikely to be more expensive than a `PartialEq::eq` call.
const AGGREGATE_EQ_MIN_LEN: usize = 4;

/// Checks whether every pattern in `elements` is a `PatKind::Constant` and,
/// if so, reconstructs a single aggregate `ty::Value` that represents the whole
/// array or slice. Returns `None` when any element is not a constant or the
/// sequence is too short to benefit from an aggregate comparison.
fn try_reconstruct_aggregate_constant<'tcx>(
tcx: TyCtxt<'tcx>,
aggregate_ty: Ty<'tcx>,
elements: &[Pat<'tcx>],
) -> Option<ty::Value<'tcx>> {
// Short arrays are not worth an aggregate comparison.
if elements.len() < AGGREGATE_EQ_MIN_LEN {
return None;
}
let branches = elements
.iter()
.map(|pat| {
if let PatKind::Constant { value } = pat.kind {
Some(ty::Const::new_value(tcx, value.valtree, value.ty))
} else {
None
}
Comment on lines +36 to +40
Copy link
Copy Markdown
Contributor

@dianne dianne Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might also be worth reconstructing aggregate constants for arrays/slices of arrays of constants, etc.? I'm not a specialization expert, but it looks like arrays of bytewise-comparable things are also bytewise-comparable, at least for common array lengths1. Since array and slice equality are specialized based on their element types' bytewise-comparability, we should be able to get better codegen for nested array patterns too (as long as the inner arrays are of one of those common lengths), I think?

View changes since the review

Footnotes

  1. https://github.com/rust-lang/rust/blob/f29256dd1420dc681bf4956e3012ffe9eccdc7e7/library/core/src/cmp/bytewise.rs#L74-L85

Copy link
Copy Markdown
Contributor Author

@jakubadamw jakubadamw Apr 26, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@dianne, interesting. I’ll look into this next! 🙂

})
.collect::<Option<Vec<_>>>()?;
let valtree = ty::ValTree::from_branches(tcx, branches);
Some(ty::Value { ty: aggregate_ty, valtree })
}

impl<'a, 'tcx> Builder<'a, 'tcx> {
/// Check if we can use aggregate `PartialEq::eq` comparisons for constant array/slice patterns.
/// This is not possible in const contexts, because `PartialEq` is not const-stable yet.
fn can_use_aggregate_eq(&self) -> bool {
let in_const_context = self.tcx.is_const_fn(self.def_id.to_def_id())
|| !self.tcx.hir_body_owner_kind(self.def_id).is_fn_or_closure();
!in_const_context
}
Comment thread
jakubadamw marked this conversation as resolved.
}

/// For an array or slice pattern's subpatterns (prefix/slice/suffix), returns a list
/// of those subpatterns, each paired with a suitably-projected [`PlaceBuilder`].
fn prefix_slice_suffix<'a, 'tcx>(
Expand Down Expand Up @@ -220,10 +262,36 @@ impl<'tcx> MatchPairTree<'tcx> {
_ => None,
};
if let Some(array_len) = array_len {
for (subplace, subpat) in
prefix_slice_suffix(&place_builder, Some(array_len), prefix, slice, suffix)
// When all elements are constants and there is no `..`
// subpattern, compare the whole array at once via
// `PartialEq::eq` rather than element by element.
if slice.is_none()
&& suffix.is_empty()
&& cx.can_use_aggregate_eq()
&& let Some(aggregate_value) =
try_reconstruct_aggregate_constant(cx.tcx, pattern.ty, prefix)
{
MatchPairTree::for_pattern(subplace, subpat, cx, &mut subpairs, extra_data);
Some(TestableCase::Constant {
value: aggregate_value,
kind: PatConstKind::Aggregate,
})
} else {
for (subplace, subpat) in prefix_slice_suffix(
&place_builder,
Some(array_len),
prefix,
slice,
suffix,
) {
MatchPairTree::for_pattern(
subplace,
subpat,
cx,
&mut subpairs,
extra_data,
);
}
None
}
} else {
// If the array length couldn't be determined, ignore the
Expand All @@ -235,33 +303,57 @@ impl<'tcx> MatchPairTree<'tcx> {
pattern.ty
),
);
None
}

None
}
PatKind::Slice { ref prefix, ref slice, ref suffix } => {
for (subplace, subpat) in
prefix_slice_suffix(&place_builder, None, prefix, slice, suffix)
// When there is no `..`, all elements are constants, and
// there are at least two of them, collapse the individual
// element subpairs into a single aggregate comparison that
// is performed after the length check.
if slice.is_none()
Comment on lines +310 to +314
Copy link
Copy Markdown
Contributor

@dianne dianne Apr 17, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

An additional possibility: even if there is a .., the comparisons for the sub-slices before and after the .. could be done via aggregate equality when applicable. Credit to #121540, which I think did this?

Edit: assuming prefixes and suffixes are typically small and hand-written, it's probably not worth the trouble to use aggregate equality for them.

Even if only handling the case with no .., it might be worth moving the special-casing into prefix_slice_suffix to share it between PatKind::Slice and PatKind::Array, since that's where the commonalities live.

Edit: after prefix_slice_suffix's cleanup in #154943, I don't think it makes much sense to put this in there. I still think the logic for deciding whether to use aggregate equality is complex enough that it could be worth factoring out, but that's probably not the way to do it.

View changes since the review

&& suffix.is_empty()
&& cx.can_use_aggregate_eq()
&& let Some(aggregate_value) =
try_reconstruct_aggregate_constant(cx.tcx, pattern.ty, prefix)
{
MatchPairTree::for_pattern(subplace, subpat, cx, &mut subpairs, extra_data);
}

if prefix.is_empty() && slice.is_some() && suffix.is_empty() {
// This pattern is shaped like `[..]`. It can match a slice
// of any length, so no length test is needed.
None
} else {
// Any other shape of slice pattern requires a length test.
// Slice patterns with a `..` subpattern require a minimum
// length; those without `..` require an exact length.
Some(TestableCase::Slice {
len: u64::try_from(prefix.len() + suffix.len()).unwrap(),
op: if slice.is_some() {
SliceLenOp::GreaterOrEqual
} else {
SliceLenOp::Equal
subpairs.push(MatchPairTree {
place,
testable_case: TestableCase::Constant {
value: aggregate_value,
kind: PatConstKind::Aggregate,
},
subpairs: Vec::new(),
pattern_span: pattern.span,
});
Some(TestableCase::Slice {
len: u64::try_from(prefix.len()).unwrap(),
op: SliceLenOp::Equal,
})
} else {
for (subplace, subpat) in
prefix_slice_suffix(&place_builder, None, prefix, slice, suffix)
{
MatchPairTree::for_pattern(subplace, subpat, cx, &mut subpairs, extra_data);
}

if prefix.is_empty() && slice.is_some() && suffix.is_empty() {
// This pattern is shaped like `[..]`. It can match
// a slice of any length, so no length test is needed.
None
} else {
// Any other shape of slice pattern requires a length test.
// Slice patterns with a `..` subpattern require a minimum
// length; those without `..` require an exact length.
Some(TestableCase::Slice {
len: u64::try_from(prefix.len() + suffix.len()).unwrap(),
op: if slice.is_some() {
SliceLenOp::GreaterOrEqual
} else {
SliceLenOp::Equal
},
})
}
}
}

Expand Down
8 changes: 8 additions & 0 deletions compiler/rustc_mir_build/src/builder/matches/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1266,6 +1266,10 @@ enum PatConstKind {
Float,
/// Constant string values, tested via string equality.
String,
/// Constant array or slice values where every element is a constant.
/// Tested by calling `PartialEq::eq` on the whole aggregate at once,
/// rather than comparing element by element.
Aggregate,
/// Any other constant-pattern is usually tested via some kind of equality
/// check. Types that might be encountered here include:
/// - raw pointers derived from integer values
Expand Down Expand Up @@ -1351,6 +1355,10 @@ enum TestKind<'tcx> {
/// Tests the place against a constant using scalar equality.
ScalarEq { value: ty::Value<'tcx> },

/// Tests the place against a constant array or slice using `PartialEq::eq`,
/// comparing the whole aggregate at once rather than element by element.
AggregateEq { value: ty::Value<'tcx> },

/// Test whether the value falls within an inclusive or exclusive range.
Range(Arc<PatRange<'tcx>>),

Expand Down
56 changes: 34 additions & 22 deletions compiler/rustc_mir_build/src/builder/matches/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
TestableCase::Constant { value, kind: PatConstKind::String } => {
TestKind::StringEq { value }
}
TestableCase::Constant { value, kind: PatConstKind::Aggregate } => {
TestKind::AggregateEq { value }
}
TestableCase::Constant { value, kind: PatConstKind::Float | PatConstKind::Other } => {
TestKind::ScalarEq { value }
}
Expand Down Expand Up @@ -137,42 +140,48 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
self.cfg.terminate(block, self.source_info(match_start_span), terminator);
}

TestKind::StringEq { value } => {
TestKind::StringEq { value } | TestKind::AggregateEq { value } => {
let tcx = self.tcx;
let success_block = target_block(TestBranch::Success);
let fail_block = target_block(TestBranch::Failure);

let ref_str_ty = Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, tcx.types.str_);
assert!(ref_str_ty.is_imm_ref_str(), "{ref_str_ty:?}");

// The string constant we're testing against has type `str`, but
// calling `<str as PartialEq>::eq` requires `&str` operands.
//
// Because `str` and `&str` have the same valtree representation,
// we can "cast" to the desired type by just replacing the type.
assert!(value.ty.is_str(), "unexpected value type for StringEq test: {value:?}");
let expected_value = ty::Value { ty: ref_str_ty, valtree: value.valtree };
let inner_ty = value.ty;
if matches!(test.kind, TestKind::StringEq { .. }) {
assert!(
inner_ty.is_str(),
"unexpected value type for StringEq test: {value:?}"
);
}
let ref_ty = Ty::new_imm_ref(tcx, tcx.lifetimes.re_erased, inner_ty);

// The constant we're testing against has type `str`, `[T; N]`, or `[T]`,
// but calling `<T as PartialEq>::eq` requires a reference operand
// (`&str`, `&[T; N]`, or `&[T]`). Valtree representations are the same
// with or without the reference wrapper, so we can "cast" to the
// desired type by just replacing the type.
let expected_value = ty::Value { ty: ref_ty, valtree: value.valtree };
let expected_value_operand =
self.literal_operand(test.span, Const::from_ty_value(tcx, expected_value));

// Similarly, the scrutinized place has type `str`, but we need `&str`.
// Get a reference by doing `let actual_value_ref_place: &str = &place`.
let actual_value_ref_place = self.temp(ref_str_ty, test.span);
// Similarly, the scrutinised place has the inner type, but we need a
// reference. Get one by doing `let actual_value_ref_place = &place`.
let actual_value_ref_place = self.temp(ref_ty, test.span);
self.cfg.push_assign(
block,
self.source_info(test.span),
actual_value_ref_place,
Rvalue::Ref(tcx.lifetimes.re_erased, BorrowKind::Shared, place),
);

// Compare two strings using `<str as std::cmp::PartialEq>::eq`.
// (Interestingly this means that exhaustiveness analysis relies, for soundness,
// on the `PartialEq` impl for `str` to be correct!)
self.string_compare(
// Compare the two values using `<T as std::cmp::PartialEq>::eq`.
// (Interestingly this means that, for `str`, exhaustiveness analysis
// relies for soundness on the `PartialEq` impl for `str` to be correct!)
self.non_scalar_compare(
block,
success_block,
fail_block,
source_info,
inner_ty,
expected_value_operand,
Operand::Copy(actual_value_ref_place),
);
Expand Down Expand Up @@ -404,19 +413,22 @@ impl<'a, 'tcx> Builder<'a, 'tcx> {
);
}

/// Compare two values of type `&str` using `<str as std::cmp::PartialEq>::eq`.
fn string_compare(
/// Compare two reference values using `<T as PartialEq>::eq`.
///
/// `compared_ty` is the *inner* type (e.g. `str`, `[u8; 64]`);
/// `expect` and `val` must already be references to that type.
fn non_scalar_compare(
&mut self,
block: BasicBlock,
success_block: BasicBlock,
fail_block: BasicBlock,
source_info: SourceInfo,
compared_ty: Ty<'tcx>,
expect: Operand<'tcx>,
val: Operand<'tcx>,
) {
let str_ty = self.tcx.types.str_;
let eq_def_id = self.tcx.require_lang_item(LangItem::PartialEq, source_info.span);
let method = trait_method(self.tcx, eq_def_id, sym::eq, [str_ty, str_ty]);
let method = trait_method(self.tcx, eq_def_id, sym::eq, [compared_ty, compared_ty]);

let bool_ty = self.tcx.types.bool;
let eq_result = self.temp(bool_ty, source_info.span);
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// MIR for `array_match` after built

fn array_match(_1: [u8; 4]) -> bool {
debug x => _1;
let mut _0: bool;
let mut _2: &[u8; 4];
let mut _3: bool;
scope 1 {
}

bb0: {
PlaceMention(_1);
_2 = &_1;
_3 = <[u8; 4] as PartialEq>::eq(copy _2, const &*b"\x01\x02\x03\x04") -> [return: bb4, unwind: bb8];
}

bb1: {
_0 = const false;
goto -> bb7;
}

bb2: {
falseEdge -> [real: bb6, imaginary: bb1];
}

bb3: {
goto -> bb1;
}

bb4: {
switchInt(move _3) -> [0: bb1, otherwise: bb2];
}

bb5: {
FakeRead(ForMatchedPlace(None), _1);
unreachable;
}

bb6: {
_0 = const true;
goto -> bb7;
}

bb7: {
return;
}

bb8 (cleanup): {
resume;
}
}
Loading
Loading