Skip to content

Commit 7be3286

Browse files
NL02copybara-github
authored andcommitted
[Explicit State Access] Update xls/passes code to use GetStateReads and GetStateReadsByStateElement
This feature isn't usable as it will fail on passes with more than one state read. This CL is focused more on the refactor. PiperOrigin-RevId: 895997589
1 parent 0a629bc commit 7be3286

14 files changed

Lines changed: 346 additions & 128 deletions

xls/ir/ir_matcher.cc

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -628,6 +628,11 @@ bool StateReadMatcher::MatchAndExplain(
628628
*listener << " has incorrect label";
629629
return false;
630630
}
631+
if (predicate_.has_value() &&
632+
!predicate_->MatchAndExplain(node->As<xls::StateRead>()->predicate(),
633+
listener)) {
634+
return false;
635+
}
631636
return true;
632637
}
633638

@@ -646,6 +651,12 @@ void StateReadMatcher::DescribeTo(::std::ostream* os) const {
646651
label_->DescribeTo(&ss);
647652
additional_fields.push_back(ss.str());
648653
}
654+
if (predicate_.has_value()) {
655+
std::stringstream ss;
656+
ss << "predicate=";
657+
predicate_->DescribeTo(&ss);
658+
additional_fields.push_back(ss.str());
659+
}
649660
DescribeToHelper(os, additional_fields);
650661
}
651662

xls/ir/ir_matcher.h

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1269,16 +1269,20 @@ inline ::testing::Matcher<const ::xls::Node*> OutputPort(
12691269
// EXPECT_THAT(x, m::StateRead());
12701270
// EXPECT_THAT(x, m::StateRead("x"));
12711271
// EXPECT_THAT(x, m::StateRead(HasSubstr("substr")));
1272+
// EXPECT_THAT(x, m::StateRead("x", /*predicate=*/m::Param("pred")));
12721273
//
12731274
class StateReadMatcher : public NodeMatcher {
12741275
public:
12751276
explicit StateReadMatcher(
12761277
std::optional<::testing::Matcher<const std::string>> state_element_name,
12771278
std::optional<::testing::Matcher<const std::optional<std::string>&>>
1278-
label = std::nullopt)
1279+
label = std::nullopt,
1280+
std::optional<::testing::Matcher<std::optional<Node*>>> predicate =
1281+
std::nullopt)
12791282
: NodeMatcher(Op::kStateRead, /*operands=*/{}),
12801283
state_element_name_(std::move(state_element_name)),
1281-
label_(std::move(label)) {}
1284+
label_(std::move(label)),
1285+
predicate_(std::move(predicate)) {}
12821286

12831287
bool MatchAndExplain(const Node* node,
12841288
::testing::MatchResultListener* listener) const override;
@@ -1287,6 +1291,7 @@ class StateReadMatcher : public NodeMatcher {
12871291
private:
12881292
std::optional<::testing::Matcher<const std::string>> state_element_name_;
12891293
std::optional<::testing::Matcher<const std::optional<std::string>&>> label_;
1294+
std::optional<::testing::Matcher<std::optional<Node*>>> predicate_;
12901295
};
12911296

12921297
template <typename T>
@@ -1324,6 +1329,14 @@ inline ::testing::Matcher<const ::xls::Node*> StateRead() {
13241329
return ::xls::op_matchers::StateReadMatcher(std::nullopt, std::nullopt);
13251330
}
13261331

1332+
inline ::testing::Matcher<const ::xls::Node*> StateRead(
1333+
::testing::Matcher<const std::string> name,
1334+
::testing::Matcher<const std::optional<std::string>&> label,
1335+
::testing::Matcher<std::optional<Node*>> predicate) {
1336+
return ::xls::op_matchers::StateReadMatcher(std::move(name), std::move(label),
1337+
std::move(predicate));
1338+
}
1339+
13271340
// Next matcher. Supported forms:
13281341
//
13291342
// EXPECT_THAT(x, m::Next());

xls/ir/proc.cc

Lines changed: 57 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -930,26 +930,46 @@ absl::StatusOr<StateRead*> Proc::TransformStateElement(
930930
Proc::StateElementTransformer& transform) {
931931
StateElement* old_state_element = old_state_read->state_element();
932932
std::string orig_name(old_state_element->name());
933-
std::string orig_read_name(old_state_read->GetNameView());
934-
XLS_ASSIGN_OR_RETURN(std::optional<Node*> read_predicate,
935-
transform.TransformReadPredicate(this, old_state_read));
933+
934+
absl::Span<StateRead* const> all_old_reads =
935+
GetStateReadsByStateElement(old_state_element);
936+
absl::flat_hash_map<StateRead*, std::string> orig_read_names;
937+
for (StateRead* old_read : all_old_reads) {
938+
orig_read_names[old_read] = std::string(old_read->GetNameView());
939+
}
940+
936941
XLS_ASSIGN_OR_RETURN(
937-
StateRead * new_state_read,
938-
AppendStateElement(absl::StrFormat("TEMP_NAME__%s__", orig_name),
939-
init_value, read_predicate,
940-
/*next_state=*/std::nullopt));
941-
new_state_read->SetLoc(old_state_read->loc());
942-
if (old_state_read->state_element()->non_synthesizable()) {
943-
new_state_read->state_element()->SetNonSynthesizable();
942+
StateElement * new_state_element,
943+
AppendUnreadStateElement(absl::StrFormat("TEMP_NAME__%s__", orig_name),
944+
init_value, old_state_read->loc()));
945+
if (old_state_element->non_synthesizable()) {
946+
new_state_element->SetNonSynthesizable();
944947
}
945-
StateElement* new_state_element = new_state_read->state_element();
946948
std::string temp_name = new_state_element->name();
947949

948-
XLS_ASSIGN_OR_RETURN(
949-
Node * new_state_value,
950-
transform.TransformStateRead(this, new_state_read, old_state_read));
951-
std::vector<std::pair<Node*, Node*>> to_replace{
952-
{old_state_read, new_state_value}};
950+
absl::flat_hash_map<StateRead*, StateRead*> old_to_new_read;
951+
StateRead* return_state_read = nullptr;
952+
std::vector<std::pair<Node*, Node*>> to_replace;
953+
954+
for (StateRead* old_read : all_old_reads) {
955+
XLS_ASSIGN_OR_RETURN(std::optional<Node*> read_predicate,
956+
transform.TransformReadPredicate(this, old_read));
957+
XLS_ASSIGN_OR_RETURN(StateRead * new_read,
958+
MakeNodeWithName<StateRead>(
959+
old_read->loc(), new_state_element, read_predicate,
960+
old_read->label(), temp_name));
961+
state_reads_[new_state_element].push_back(new_read);
962+
old_to_new_read[old_read] = new_read;
963+
964+
if (old_read == old_state_read) {
965+
return_state_read = new_read;
966+
}
967+
968+
XLS_ASSIGN_OR_RETURN(Node * new_state_value, transform.TransformStateRead(
969+
this, new_read, old_read));
970+
to_replace.push_back({old_read, new_state_value});
971+
}
972+
953973
struct NextTransformation {
954974
Next* old_next;
955975
Node* new_value;
@@ -959,14 +979,18 @@ absl::StatusOr<StateRead*> Proc::TransformStateElement(
959979
for (Next* nxt : next_values(old_state_element)) {
960980
NextTransformation& new_next = transforms.emplace_back();
961981
new_next.old_next = nxt;
962-
XLS_ASSIGN_OR_RETURN(new_next.new_value, transform.TransformNextValue(
963-
this, new_state_read, nxt));
964-
XLS_RET_CHECK(new_next.new_value->GetType() == new_state_read->GetType())
982+
StateRead* corresponding_new_read =
983+
old_to_new_read.at(nxt->state_read()->As<StateRead>());
984+
XLS_ASSIGN_OR_RETURN(
985+
new_next.new_value,
986+
transform.TransformNextValue(this, corresponding_new_read, nxt));
987+
XLS_RET_CHECK(new_next.new_value->GetType() ==
988+
corresponding_new_read->GetType())
965989
<< "New value is not compatible type. Expected: "
966-
<< new_state_read->GetType() << " got " << new_next.new_value;
990+
<< corresponding_new_read->GetType() << " got " << new_next.new_value;
967991
XLS_ASSIGN_OR_RETURN(
968992
new_next.new_predicate,
969-
transform.TransformNextPredicate(this, new_state_read, nxt));
993+
transform.TransformNextPredicate(this, corresponding_new_read, nxt));
970994
}
971995

972996
// We've transformed all the graph elements. Start replacing them.
@@ -977,24 +1001,30 @@ absl::StatusOr<StateRead*> Proc::TransformStateElement(
9771001
auto orig_storage = state_elements_.extract(orig_name);
9781002
orig_storage.key() = to_remove_name;
9791003
old_state_element->SetName(to_remove_name);
980-
old_state_read->SetName(to_remove_name);
1004+
for (StateRead* old_read : all_old_reads) {
1005+
old_read->SetName(to_remove_name);
1006+
}
9811007
CHECK(state_elements_.insert(std::move(orig_storage)).inserted);
9821008

9831009
// Take over the old state element & read names.
9841010
auto new_storage = state_elements_.extract(temp_name);
9851011
new_storage.key() = orig_name;
9861012
new_state_element->SetName(orig_name);
987-
new_state_read->SetNameDirectly(orig_read_name);
1013+
for (auto& [old_read, new_read] : old_to_new_read) {
1014+
new_read->SetNameDirectly(orig_read_names.at(old_read));
1015+
}
9881016
CHECK(state_elements_.insert(std::move(new_storage)).inserted);
9891017

9901018
// Identity-ify the old next nodes and create new ones.
9911019
for (const NextTransformation& nt : transforms) {
9921020
// Make the next
1021+
StateRead* corresponding_new_read =
1022+
old_to_new_read.at(nt.old_next->state_read()->As<StateRead>());
9931023
XLS_ASSIGN_OR_RETURN(
9941024
Next * nxt,
995-
MakeNodeWithName<Next>(nt.old_next->loc(), new_state_read, nt.new_value,
996-
nt.new_predicate, nt.old_next->label(),
997-
nt.old_next->GetName()));
1025+
MakeNodeWithName<Next>(nt.old_next->loc(), corresponding_new_read,
1026+
nt.new_value, nt.new_predicate,
1027+
nt.old_next->label(), nt.old_next->GetName()));
9981028
to_replace.push_back({nt.old_next, nxt});
9991029
// Identity-ify the old next.
10001030
XLS_RETURN_IF_ERROR(nt.old_next->ReplaceOperandNumber(
@@ -1011,7 +1041,7 @@ absl::StatusOr<StateRead*> Proc::TransformStateElement(
10111041
},
10121042
/*replace_implicit_uses=*/false));
10131043
}
1014-
return new_state_read;
1044+
return return_state_read;
10151045
}
10161046

10171047
absl::Status Proc::InternalRebuildSideTables() {

xls/ir/proc_test.cc

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -567,6 +567,75 @@ TEST_F(ProcTest, TransformStateElement) {
567567
EXPECT_THAT(user.node(), m::Tuple(m::Neg(new_st)));
568568
}
569569

570+
TEST_F(ProcTest, TransformStateElementMultipleReads) {
571+
auto p = CreatePackage();
572+
TokenlessProcBuilder pb(TestName(), "tkn", p.get());
573+
auto st = pb.StateElement("st", UBits(0b1010, 4));
574+
auto cond = pb.StateElement("cond", UBits(0, 1));
575+
auto add_st = pb.Next(st, pb.Add(st, pb.Literal(UBits(1, 4))), cond);
576+
pb.Next(cond, pb.Not(cond));
577+
578+
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build());
579+
580+
// Manually add a second read for 'st'
581+
StateElement* st_elem = proc->GetStateElement(0);
582+
XLS_ASSERT_OK_AND_ASSIGN(
583+
StateRead * st_read2,
584+
proc->MakeNodeWithName<StateRead>(SourceInfo(), st_elem,
585+
/*predicate=*/std::nullopt,
586+
/*label=*/std::nullopt, "st_read2"));
587+
XLS_ASSERT_OK_AND_ASSIGN(
588+
Node * lit_sub,
589+
proc->MakeNode<Literal>(SourceInfo(), Value(UBits(2, 4))));
590+
XLS_ASSERT_OK_AND_ASSIGN(
591+
Node * sub_st2,
592+
proc->MakeNode<BinOp>(SourceInfo(), st_read2, lit_sub, Op::kSub));
593+
XLS_ASSERT_OK_AND_ASSIGN(
594+
Next * next_st2,
595+
proc->MakeNodeWithName<Next>(SourceInfo(), st_read2, sub_st2,
596+
/*predicate=*/std::nullopt,
597+
/*label=*/std::nullopt, "next_st2"));
598+
XLS_ASSERT_OK(proc->RebuildSideTables());
599+
600+
// Verify side tables
601+
EXPECT_EQ(proc->GetStateReadsByStateElement(st_elem).size(), 2);
602+
603+
// Test transformer (invert param)
604+
struct TestTransformer : public Proc::StateElementTransformer {
605+
public:
606+
absl::StatusOr<Node*> TransformStateRead(
607+
Proc* proc, StateRead* new_state_read,
608+
StateRead* old_state_read) override {
609+
return proc->MakeNode<UnOp>(new_state_read->loc(), new_state_read,
610+
Op::kNeg);
611+
}
612+
absl::StatusOr<Node*> TransformNextValue(Proc* proc,
613+
StateRead* new_state_read,
614+
Next* old_next) override {
615+
return proc->MakeNode<UnOp>(old_next->value()->loc(), old_next->value(),
616+
Op::kNeg);
617+
}
618+
};
619+
TestTransformer tt;
620+
XLS_ASSERT_OK_AND_ASSIGN(
621+
StateRead * new_st,
622+
proc->TransformStateElement(st.node()->As<StateRead>(),
623+
Value(UBits(0b0101, 4)), tt));
624+
625+
// Verify the first read and its next were transformed
626+
EXPECT_THAT(new_st, m::StateRead("st"));
627+
EXPECT_THAT(add_st.node(), m::Next(st.node(), st.node(), cond.node()));
628+
629+
// Verify the second read and its next were transformed
630+
Node* new_st_read2 = FindNode("st_read2", proc);
631+
ASSERT_NE(new_st_read2, nullptr);
632+
EXPECT_NE(new_st_read2, new_st);
633+
EXPECT_THAT(new_st_read2->As<StateRead>()->state_element(),
634+
new_st->state_element());
635+
636+
EXPECT_THAT(next_st2, m::Next(st_read2, st_read2));
637+
}
638+
570639
class ScheduledProcTest : public IrTestBase {
571640
protected:
572641
absl::StatusOr<ScheduledProc*> CreateScheduledProc(Package* p) {

xls/passes/BUILD

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2507,6 +2507,7 @@ xls_pass(
25072507
"//xls/ir:type",
25082508
"//xls/ir:value",
25092509
"//xls/ir:value_utils",
2510+
"@com_google_absl//absl/algorithm:container",
25102511
"@com_google_absl//absl/container:btree",
25112512
"@com_google_absl//absl/container:flat_hash_map",
25122513
"@com_google_absl//absl/container:flat_hash_set",
@@ -2565,6 +2566,7 @@ xls_pass(
25652566
"@com_google_absl//absl/log",
25662567
"@com_google_absl//absl/status:statusor",
25672568
"@com_google_absl//absl/strings",
2569+
"@com_google_absl//absl/types:span",
25682570
],
25692571
)
25702572

@@ -2576,6 +2578,7 @@ xls_pass(
25762578
deps = [
25772579
":optimization_pass",
25782580
":pass_base",
2581+
"//xls/common/status:ret_check",
25792582
"//xls/common/status:status_macros",
25802583
"//xls/ir",
25812584
"//xls/ir:bits",

xls/passes/array_untuple_pass.cc

Lines changed: 25 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -119,17 +119,32 @@ absl::StatusOr<absl::flat_hash_set<Node*>> FindExternalGroups(
119119
// Don't mess with params that are only used in identity updates. Would
120120
// infinite loop otherwise since we don't remove these very often.
121121
for (StateElement* state_element : f->AsProcOrDie()->StateElements()) {
122-
StateRead* state_read =
123-
f->AsProcOrDie()->GetStateReadByStateElement(state_element);
124-
if (absl::c_all_of(state_read->users(), [&](Node* n) -> bool {
125-
if (n->Is<Next>()) {
122+
absl::Span<StateRead* const> state_reads =
123+
f->AsProcOrDie()->GetStateReadsByStateElement(state_element);
124+
bool all_reads_identity = true;
125+
for (StateRead* state_read : state_reads) {
126+
if (!absl::c_all_of(state_read->users(), [&](Node* n) -> bool {
127+
if (!n->Is<Next>()) {
128+
return false;
129+
}
126130
Next* nxt = n->As<Next>();
127-
return nxt->state_read() == nxt->value() &&
128-
nxt->state_read() == state_read;
129-
}
130-
return false;
131-
})) {
132-
excluded.insert(groups.Find(state_read));
131+
if (nxt->state_element() != state_element) {
132+
return false;
133+
}
134+
if (!nxt->value()->Is<StateRead>()) {
135+
return false;
136+
}
137+
return nxt->value()->As<StateRead>()->state_element() ==
138+
state_element;
139+
})) {
140+
all_reads_identity = false;
141+
break;
142+
}
143+
}
144+
if (all_reads_identity) {
145+
for (StateRead* state_read : state_reads) {
146+
excluded.insert(groups.Find(state_read));
147+
}
133148
}
134149
}
135150
}

0 commit comments

Comments
 (0)