Skip to content

Commit eb22bce

Browse files
NL02copybara-github
authored andcommitted
[Explicit State Access] Update xls/passes code to use GetStateReads and GetStateReadsByStateElement
This feature isn't usable as it will fail on passes with more than one state read. This CL is focused more on the refactor. PiperOrigin-RevId: 895997589
1 parent a0ecdeb commit eb22bce

18 files changed

Lines changed: 287 additions & 127 deletions

xls/ir/ir_matcher.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,11 @@ bool StateReadMatcher::MatchAndExplain(
628628
*listener << " has incorrect label";
629629
return false;
630630
}
631+
if (predicate_.has_value() &&
632+
!predicate_->MatchAndExplain(node->As<xls::StateRead>()->predicate(),
633+
listener)) {
634+
return false;
635+
}
631636
return true;
632637
}
633638

@@ -646,6 +651,12 @@ void StateReadMatcher::DescribeTo(::std::ostream* os) const {
646651
label_->DescribeTo(&ss);
647652
additional_fields.push_back(ss.str());
648653
}
654+
if (predicate_.has_value()) {
655+
std::stringstream ss;
656+
ss << "predicate=";
657+
predicate_->DescribeTo(&ss);
658+
additional_fields.push_back(ss.str());
659+
}
649660
DescribeToHelper(os, additional_fields);
650661
}
651662

xls/ir/ir_matcher.h

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1269,16 +1269,20 @@ inline ::testing::Matcher<const ::xls::Node*> OutputPort(
12691269
// EXPECT_THAT(x, m::StateRead());
12701270
// EXPECT_THAT(x, m::StateRead("x"));
12711271
// EXPECT_THAT(x, m::StateRead(HasSubstr("substr")));
1272+
// EXPECT_THAT(x, m::StateRead("x", /*predicate=*/m::Param("pred")));
12721273
//
12731274
class StateReadMatcher : public NodeMatcher {
12741275
public:
12751276
explicit StateReadMatcher(
12761277
std::optional<::testing::Matcher<const std::string>> state_element_name,
12771278
std::optional<::testing::Matcher<const std::optional<std::string>&>>
1278-
label = std::nullopt)
1279+
label = std::nullopt,
1280+
std::optional<::testing::Matcher<std::optional<Node*>>> predicate =
1281+
std::nullopt)
12791282
: NodeMatcher(Op::kStateRead, /*operands=*/{}),
12801283
state_element_name_(std::move(state_element_name)),
1281-
label_(std::move(label)) {}
1284+
label_(std::move(label)),
1285+
predicate_(std::move(predicate)) {}
12821286

12831287
bool MatchAndExplain(const Node* node,
12841288
::testing::MatchResultListener* listener) const override;
@@ -1287,6 +1291,7 @@ class StateReadMatcher : public NodeMatcher {
12871291
private:
12881292
std::optional<::testing::Matcher<const std::string>> state_element_name_;
12891293
std::optional<::testing::Matcher<const std::optional<std::string>&>> label_;
1294+
std::optional<::testing::Matcher<std::optional<Node*>>> predicate_;
12901295
};
12911296

12921297
template <typename T>
@@ -1324,6 +1329,14 @@ inline ::testing::Matcher<const ::xls::Node*> StateRead() {
13241329
return ::xls::op_matchers::StateReadMatcher(std::nullopt, std::nullopt);
13251330
}
13261331

1332+
inline ::testing::Matcher<const ::xls::Node*> StateRead(
1333+
::testing::Matcher<const std::string> name,
1334+
::testing::Matcher<const std::optional<std::string>&> label,
1335+
::testing::Matcher<std::optional<Node*>> predicate) {
1336+
return ::xls::op_matchers::StateReadMatcher(std::move(name), std::move(label),
1337+
std::move(predicate));
1338+
}
1339+
13271340
// Next matcher. Supported forms:
13281341
//
13291342
// EXPECT_THAT(x, m::Next());

xls/passes/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2507,6 +2507,7 @@ xls_pass(
25072507
"//xls/ir:type",
25082508
"//xls/ir:value",
25092509
"//xls/ir:value_utils",
2510+
"@com_google_absl//absl/algorithm:container",
25102511
"@com_google_absl//absl/container:btree",
25112512
"@com_google_absl//absl/container:flat_hash_map",
25122513
"@com_google_absl//absl/container:flat_hash_set",
@@ -2565,6 +2566,7 @@ xls_pass(
25652566
"@com_google_absl//absl/log",
25662567
"@com_google_absl//absl/status:statusor",
25672568
"@com_google_absl//absl/strings",
2569+
"@com_google_absl//absl/types:span",
25682570
],
25692571
)
25702572

@@ -2576,6 +2578,7 @@ xls_pass(
25762578
deps = [
25772579
":optimization_pass",
25782580
":pass_base",
2581+
"//xls/common/status:ret_check",
25792582
"//xls/common/status:status_macros",
25802583
"//xls/ir",
25812584
"//xls/ir:bits",

xls/passes/array_untuple_pass.cc

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -119,17 +119,32 @@ absl::StatusOr<absl::flat_hash_set<Node*>> FindExternalGroups(
119119
// Don't mess with params that are only used in identity updates. Would
120120
// infinite loop otherwise since we don't remove these very often.
121121
for (StateElement* state_element : f->AsProcOrDie()->StateElements()) {
122-
StateRead* state_read =
123-
f->AsProcOrDie()->GetStateReadByStateElement(state_element);
124-
if (absl::c_all_of(state_read->users(), [&](Node* n) -> bool {
125-
if (n->Is<Next>()) {
122+
absl::Span<StateRead* const> state_reads =
123+
f->AsProcOrDie()->GetStateReadsByStateElement(state_element);
124+
bool all_reads_identity = true;
125+
for (StateRead* state_read : state_reads) {
126+
if (!absl::c_all_of(state_read->users(), [&](Node* n) -> bool {
127+
if (!n->Is<Next>()) {
128+
return false;
129+
}
126130
Next* nxt = n->As<Next>();
127-
return nxt->state_read() == nxt->value() &&
128-
nxt->state_read() == state_read;
129-
}
130-
return false;
131-
})) {
132-
excluded.insert(groups.Find(state_read));
131+
if (nxt->state_element() != state_element) {
132+
return false;
133+
}
134+
if (!nxt->value()->Is<StateRead>()) {
135+
return false;
136+
}
137+
return nxt->value()->As<StateRead>()->state_element() ==
138+
state_element;
139+
})) {
140+
all_reads_identity = false;
141+
break;
142+
}
143+
}
144+
if (all_reads_identity) {
145+
for (StateRead* state_read : state_reads) {
146+
excluded.insert(groups.Find(state_read));
147+
}
133148
}
134149
}
135150
}

xls/passes/array_untuple_pass_test.cc

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -458,6 +458,59 @@ TEST_F(ArrayUntuplePassTest, ProcStateArrayImplicitNext) {
458458
m::StateElement(_, m::Type("bits[3][4]"))}));
459459
}
460460

461+
TEST_F(ArrayUntuplePassTest, ProcStateArrayActiveMultiReadOptimization) {
462+
auto p = CreatePackage();
463+
XLS_ASSERT_OK_AND_ASSIGN(
464+
auto chan_resp, p->CreateStreamingChannel("resp", ChannelOps::kSendOnly,
465+
p->GetBitsType(3)));
466+
ProcBuilder pb(TestName(), p.get());
467+
BValue state = pb.StateElement(
468+
"foo", ValueBuilder::ArrayB({
469+
ValueBuilder::Tuple({ValueBuilder::Bits(UBits(0, 1)),
470+
ValueBuilder::Bits(UBits(1, 3))}),
471+
}));
472+
473+
// Read #1: Extract the inner tuple value out of the array
474+
BValue extracted_tuple = pb.ArrayIndex(state, {pb.Literal(UBits(0, 1))});
475+
BValue extracted_data = pb.TupleIndex(extracted_tuple, 1);
476+
477+
// Send required by equivalence checker to observe proc mutations
478+
pb.Send(chan_resp, pb.Literal(Value::Token()), extracted_data);
479+
480+
// Read #2: Perform an independent update on the state
481+
BValue updated_tuple = pb.Tuple({pb.Literal(UBits(1, 1)), extracted_data});
482+
BValue next_state =
483+
pb.ArrayUpdate(state, updated_tuple, {pb.Literal(UBits(0, 1))});
484+
485+
XLS_ASSERT_OK_AND_ASSIGN(Proc * pr, pb.Build({next_state}));
486+
solvers::z3::ScopedVerifyProcEquivalence svpe(pr, /*activation_count=*/2,
487+
/*include_state=*/false);
488+
ScopedRecordIr sri(p.get());
489+
ASSERT_THAT(RunPass(p.get()), IsOkAndHolds(true));
490+
EXPECT_THAT(pr->StateElements(),
491+
IsSupersetOf({m::StateElement(_, m::Type("bits[1][1]")),
492+
m::StateElement(_, m::Type("bits[3][1]"))}));
493+
}
494+
495+
TEST_F(ArrayUntuplePassTest, ProcStateArrayIdentityUpdateOnly) {
496+
auto p = CreatePackage();
497+
ProcBuilder pb(TestName(), p.get());
498+
BValue state = pb.StateElement(
499+
"foo", ValueBuilder::ArrayB({
500+
ValueBuilder::Tuple({ValueBuilder::Bits(UBits(0, 1)),
501+
ValueBuilder::Bits((UBits(1, 3)))}),
502+
}));
503+
pb.Next(state, state);
504+
505+
XLS_ASSERT_OK(pb.Build().status());
506+
ScopedRecordIr sri(p.get());
507+
508+
ArrayUntuplePass pass;
509+
PassResults res;
510+
OptimizationContext ctx;
511+
ASSERT_THAT(pass.Run(p.get(), {}, &res, ctx), IsOkAndHolds(false));
512+
}
513+
461514
void IrFuzzArrayUntuple(FuzzPackageWithArgs fuzz_package_with_args) {
462515
ArrayUntuplePass pass;
463516
OptimizationPassChangesOutputs(std::move(fuzz_package_with_args), pass);

xls/passes/canonicalization_pass_test.cc

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,12 @@ namespace xls {
4949
namespace {
5050

5151
using ::absl_testing::IsOkAndHolds;
52+
using ::testing::_;
5253
using ::testing::ElementsAre;
54+
using ::testing::Eq;
5355
using ::testing::IsEmpty;
5456
using ::testing::Optional;
57+
using ::testing::UnorderedElementsAre;
5558

5659
class CanonicalizePassTest : public IrTestBase {
5760
protected:
@@ -348,10 +351,13 @@ TEST_F(CanonicalizePassTest, StateReadWithAlwaysTruePredicate) {
348351
/*read_predicate=*/pb.Literal(UBits(1, 1)));
349352
pb.Next(x, pb.Literal(UBits(1, 32)));
350353
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build());
351-
EXPECT_THAT(proc->GetStateRead(0)->predicate(), Optional(m::Literal(1)));
354+
EXPECT_THAT(
355+
proc->GetStateReads(0).front(),
356+
UnorderedElementsAre(m::StateRead(_, _, Optional(m::Literal(1)))));
352357

353358
EXPECT_THAT(Run(p.get()), IsOkAndHolds(true));
354-
EXPECT_EQ(proc->GetStateRead(0)->predicate(), std::nullopt);
359+
EXPECT_THAT(proc->GetStateReads(0).front(),
360+
UnorderedElementsAre(m::StateRead(_, _, Eq(std::nullopt))));
355361
}
356362

357363
void IrFuzzCanonicalization(FuzzPackageWithArgs fuzz_package_with_args) {

xls/passes/conditional_specialization_pass_test.cc

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1266,14 +1266,16 @@ TEST_F(ConditionalSpecializationPassTest, StateReadSpecialization) {
12661266
Run(proc, /*use_bdd=*/true, /*optimize_for_best_case_throughput=*/true),
12671267
IsOkAndHolds(true));
12681268

1269-
EXPECT_THAT(
1270-
proc->GetStateReadByStateElement(*proc->GetStateElementByName("counter0"))
1271-
->predicate(),
1272-
Optional(m::Not(m::StateRead("index"))));
1273-
EXPECT_THAT(
1274-
proc->GetStateReadByStateElement(*proc->GetStateElementByName("counter1"))
1275-
->predicate(),
1276-
Optional(m::StateRead("index")));
1269+
EXPECT_THAT(proc->GetStateReadsByStateElement(
1270+
*proc->GetStateElementByName("counter0"))
1271+
.front()
1272+
->predicate(),
1273+
Optional(m::Not(m::StateRead("index"))));
1274+
EXPECT_THAT(proc->GetStateReadsByStateElement(
1275+
*proc->GetStateElementByName("counter1"))
1276+
.front()
1277+
->predicate(),
1278+
Optional(m::StateRead("index")));
12771279
}
12781280

12791281
TEST_F(ConditionalSpecializationPassTest, HarderStateReadSpecialization) {
@@ -1303,16 +1305,20 @@ TEST_F(ConditionalSpecializationPassTest, HarderStateReadSpecialization) {
13031305
Run(proc, /*use_bdd=*/true, /*optimize_for_best_case_throughput=*/true),
13041306
IsOkAndHolds(true));
13051307

1308+
EXPECT_THAT(proc->GetStateReadsByStateElement(
1309+
*proc->GetStateElementByName("counter0"))
1310+
.front()
1311+
->predicate(),
1312+
Optional(m::Eq(m::StateRead("index"), m::Literal(0))));
1313+
EXPECT_THAT(proc->GetStateReadsByStateElement(
1314+
*proc->GetStateElementByName("counter1"))
1315+
.front()
1316+
->predicate(),
1317+
Optional(m::Eq(m::StateRead("index"), m::Literal(1))));
13061318
EXPECT_THAT(
1307-
proc->GetStateReadByStateElement(*proc->GetStateElementByName("counter0"))
1308-
->predicate(),
1309-
Optional(m::Eq(m::StateRead("index"), m::Literal(0))));
1310-
EXPECT_THAT(
1311-
proc->GetStateReadByStateElement(*proc->GetStateElementByName("counter1"))
1312-
->predicate(),
1313-
Optional(m::Eq(m::StateRead("index"), m::Literal(1))));
1314-
EXPECT_THAT(
1315-
proc->GetStateReadByStateElement(*proc->GetStateElementByName("counter2"))
1319+
proc->GetStateReadsByStateElement(
1320+
*proc->GetStateElementByName("counter2"))
1321+
.front()
13161322
->predicate(),
13171323
Optional(m::And(
13181324
// High bit of index is set

xls/passes/proc_state_array_flattening_pass.cc

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include "absl/log/log.h"
2121
#include "absl/status/statusor.h"
2222
#include "absl/strings/str_cat.h"
23+
#include "absl/types/span.h"
2324
#include "xls/common/math_util.h"
2425
#include "xls/common/status/ret_check.h"
2526
#include "xls/common/status/status_macros.h"
@@ -131,10 +132,18 @@ absl::StatusOr<bool> SimplifyProcState(Proc* proc,
131132
Value new_init_value = Value::Tuple(old_init_value.elements());
132133

133134
ArrayToTupleStateTransformer transformer;
134-
XLS_RETURN_IF_ERROR(proc->TransformStateElement(
135-
proc->GetStateReadByStateElement(state_element),
136-
new_init_value, transformer)
137-
.status());
135+
absl::Span<StateRead* const> state_reads =
136+
proc->GetStateReadsByStateElement(state_element);
137+
138+
XLS_ASSIGN_OR_RETURN(StateRead * new_state_read,
139+
proc->TransformStateElement(
140+
state_reads.front(), new_init_value, transformer));
141+
for (int64_t i = 0; i < state_reads.size(); ++i) {
142+
XLS_ASSIGN_OR_RETURN(
143+
Node * replacement,
144+
transformer.TransformStateRead(proc, new_state_read, state_reads[i]));
145+
XLS_RETURN_IF_ERROR(state_reads[i]->ReplaceUsesWith(replacement));
146+
}
138147

139148
std::vector<Next*> old_next_values(proc->next_values(state_element).begin(),
140149
proc->next_values(state_element).end());

xls/passes/proc_state_bits_shattering_pass.cc

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "absl/strings/str_format.h"
2828
#include "absl/types/span.h"
2929
#include "cppitertools/reversed.hpp"
30+
#include "xls/common/status/ret_check.h"
3031
#include "xls/common/status/status_macros.h"
3132
#include "xls/ir/bits.h"
3233
#include "xls/ir/node.h"
@@ -117,7 +118,12 @@ absl::StatusOr<bool> MaybeSplitStateElements(
117118
// to use STL set intersection algorithms.
118119
std::vector<int64_t> split_ends;
119120
bool could_benefit_from_splitting = false;
120-
StateRead* state_read = proc->GetStateReadByStateElement(state_element);
121+
absl::Span<StateRead* const> state_reads =
122+
proc->GetStateReadsByStateElement(state_element);
123+
XLS_RET_CHECK_EQ(state_reads.size(), 1)
124+
<< "ProcStateBitsShatteringPass only supports one StateRead per "
125+
"StateElement for now.";
126+
StateRead* state_read = state_reads.front();
121127
for (Next* next : proc->next_values(state_element)) {
122128
if (next->value() == state_read) {
123129
// This is a no-op next-value; it doesn't affect whether or not it's

xls/passes/proc_state_narrowing_pass.cc

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,12 @@ absl::StatusOr<bool> ProcStateNarrowingPass::RunOnProcInternal(
161161
<< state_element->type()->ToString();
162162
continue;
163163
}
164-
StateRead* state_read = proc->GetStateReadByStateElement(state_element);
164+
absl::Span<StateRead* const> state_reads =
165+
proc->GetStateReadsByStateElement(state_element);
166+
XLS_RET_CHECK_EQ(state_reads.size(), 1)
167+
<< "ProcStateNarrowingPass only supports one StateRead per "
168+
"StateElement for now.";
169+
StateRead* state_read = state_reads.front();
165170
std::optional<SharedLeafTypeTree<TernaryVector>> ternary =
166171
qe.GetTernary(state_read);
167172
if (!ternary) {

0 commit comments

Comments
 (0)