Skip to content

Commit f5cebd6

Browse files
NL02copybara-github
authored andcommitted
[Explicit State Access] Support multiple State Reads nodes per State Element in proc_state_legalization_pass
Only apply state_read predicate legalization (incorporating next_value predicates) for state elements with exactly one state_read. Prevent nullptr from attempting to fetch State Reads from Next Nodes constructed from a State Element. PiperOrigin-RevId: 910694399
1 parent 4e4f4e8 commit f5cebd6

3 files changed

Lines changed: 109 additions & 8 deletions

File tree

xls/scheduling/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,7 @@ cc_library(
459459
"//xls/ir:value",
460460
"//xls/passes:pass_base",
461461
"//xls/solvers:z3_ir_translator",
462+
"@com_google_absl//absl/algorithm:container",
462463
"@com_google_absl//absl/container:btree",
463464
"@com_google_absl//absl/container:flat_hash_set",
464465
"@com_google_absl//absl/log",

xls/scheduling/proc_state_legalization_pass.cc

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
#include <variant>
2121
#include <vector>
2222

23+
#include "absl/algorithm/container.h"
2324
#include "absl/container/btree_set.h"
2425
#include "absl/container/flat_hash_set.h"
2526
#include "absl/log/check.h"
@@ -53,9 +54,17 @@ namespace {
5354
absl::StatusOr<bool> LegalizeStateReadPredicate(
5455
Proc* proc, StateElement* state_element,
5556
const SchedulingPassOptions& options) {
56-
StateRead* state_read = proc->GetStateReadByStateElement(state_element);
57+
absl::Span<StateRead* const> reads =
58+
proc->GetStateReadsByStateElement(state_element);
5759
const absl::btree_set<Next*, Node::NodeIdLessThan>& next_values =
5860
proc->next_values(state_element);
61+
62+
// Only legalize predicate if there is exactly one state_read.
63+
if (reads.size() != 1) {
64+
return false;
65+
}
66+
67+
StateRead* state_read = reads.front();
5968
if (!state_read->predicate().has_value() || next_values.empty()) {
6069
// Already unconditional, or no explicit `next_value`s; nothing to do.
6170
return false;
@@ -66,7 +75,7 @@ absl::StatusOr<bool> LegalizeStateReadPredicate(
6675
predicates.reserve(1 + next_values.size());
6776
predicates_set.reserve(next_values.size());
6877
for (Next* next : next_values) {
69-
if (next->state_read() == next->value()) {
78+
if (state_read == next->value()) {
7079
// This is a no-op next_value; we will narrow it to the case where the
7180
// state read is active instead.
7281
continue;
@@ -215,8 +224,12 @@ absl::StatusOr<bool> AddMutualExclusionAsserts(
215224
absl::StatusOr<bool> AddWriteWithoutReadAsserts(
216225
Proc* proc, StateElement* state_element,
217226
const SchedulingPassOptions& options) {
218-
StateRead* state_read = proc->GetStateReadByStateElement(state_element);
219-
if (!state_read->predicate().has_value()) {
227+
absl::Span<StateRead* const> reads =
228+
proc->GetStateReadsByStateElement(state_element);
229+
230+
if (absl::c_any_of(reads, [](StateRead* state_read) {
231+
return !state_read->predicate().has_value();
232+
})) {
220233
return false;
221234
}
222235

@@ -226,7 +239,23 @@ absl::StatusOr<bool> AddWriteWithoutReadAsserts(
226239
return false;
227240
}
228241

229-
std::vector<Node*> predicate_list;
242+
std::vector<Node*> read_predicates;
243+
read_predicates.reserve(reads.size());
244+
for (StateRead* read : reads) {
245+
read_predicates.push_back(*read->predicate());
246+
}
247+
248+
Node* any_read_active;
249+
if (reads.size() == 1) {
250+
any_read_active = *reads[0]->predicate();
251+
} else {
252+
XLS_ASSIGN_OR_RETURN(
253+
any_read_active,
254+
proc->MakeNodeWithName<NaryOp>(
255+
SourceInfo(), read_predicates, Op::kOr,
256+
absl::StrCat("__", state_element->name(), "__any_read_active")));
257+
}
258+
230259
for (Next* next : next_values) {
231260
XLS_RET_CHECK(next->predicate().has_value());
232261
XLS_ASSIGN_OR_RETURN(
@@ -239,8 +268,7 @@ absl::StatusOr<bool> AddWriteWithoutReadAsserts(
239268
Node * no_write_without_read,
240269
proc->MakeNodeWithName<NaryOp>(
241270
SourceInfo(),
242-
absl::MakeConstSpan({*state_read->predicate(), next_not_triggered}),
243-
Op::kOr,
271+
absl::MakeConstSpan({any_read_active, next_not_triggered}), Op::kOr,
244272
absl::StrCat("__", state_element->name(), "__no_next_", next->id(),
245273
"_without_read")));
246274
std::string label = absl::StrCat("__", state_element->name(), "__next_",
@@ -291,7 +319,18 @@ absl::StatusOr<bool> AddDefaultNextValue(Proc* proc,
291319
StateElement* state_element,
292320
const SchedulingPassOptions& options) {
293321
absl::btree_set<Node*, Node::NodeIdLessThan> predicates;
294-
StateRead* state_read = proc->GetStateReadByStateElement(state_element);
322+
absl::Span<StateRead* const> reads =
323+
proc->GetStateReadsByStateElement(state_element);
324+
StateRead* state_read = nullptr;
325+
for (StateRead* read : reads) {
326+
if (!read->predicate().has_value()) {
327+
state_read = read;
328+
break;
329+
}
330+
}
331+
if (state_read == nullptr) {
332+
state_read = reads.front();
333+
}
295334
for (Next* next : proc->next_values(state_element)) {
296335
if (next->predicate().has_value()) {
297336
predicates.insert(*next->predicate());

xls/scheduling/proc_state_legalization_pass_test.cc

Lines changed: 61 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,67 @@ TEST_P(ProcStateLegalizationPassTest,
549549
m::Literal(0))))))));
550550
}
551551

552+
TEST_P(ProcStateLegalizationPassTest,
553+
ProcWithMultipleStateReadsAndDefaultNextValue) {
554+
auto p = CreatePackage();
555+
ProcBuilder pb("p", p.get());
556+
XLS_ASSERT_OK_AND_ASSIGN(
557+
StateElement * x_se,
558+
pb.proc()->AppendUnreadStateElement("x", Value(UBits(0, 32))));
559+
BValue x = pb.StateRead(x_se);
560+
561+
XLS_ASSERT_OK_AND_ASSIGN(
562+
StateElement * y_se,
563+
pb.proc()->AppendUnreadStateElement("y", Value(UBits(0, 32))));
564+
565+
BValue x_even =
566+
pb.Eq(pb.UMod(x, pb.Literal(UBits(2, 32))), pb.Literal(UBits(0, 32)));
567+
BValue x_multiple_of_3 =
568+
pb.Eq(pb.UMod(x, pb.Literal(UBits(3, 32))), pb.Literal(UBits(0, 32)));
569+
570+
BValue y_read1 = pb.StateRead(y_se, x_even);
571+
BValue y_read2 = pb.StateRead(y_se, x_multiple_of_3);
572+
573+
BValue y_val1 = pb.Add(y_read1, pb.Literal(UBits(3, 32)));
574+
BValue y_val2 = pb.Add(y_read2, pb.Literal(UBits(2, 32)));
575+
576+
BValue y_cond = pb.Select(x_even, {y_val2, y_val1});
577+
578+
BValue cond_write =
579+
pb.Eq(pb.UMod(x, pb.Literal(UBits(4, 32))), pb.Literal(UBits(0, 32)));
580+
pb.Next(y_se, pb.Add(y_cond, pb.Literal(UBits(1, 32))), cond_write);
581+
582+
pb.Next(x_se, pb.Add(x, pb.Literal(UBits(1, 32))));
583+
584+
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build());
585+
XLS_ASSERT_OK(p->SetTop(proc));
586+
587+
ASSERT_THAT(Run(proc), IsOkAndHolds(true));
588+
589+
// 1. Verify the generated safety assertions
590+
std::vector<Node*> asserts;
591+
absl::c_copy_if(proc->nodes(), std::back_inserter(asserts),
592+
[](Node* node) { return node->Is<Assert>(); });
593+
594+
auto x_even_matcher =
595+
m::Eq(m::UMod(m::StateRead("x"), m::Literal(2)), m::Literal(0));
596+
auto x_multiple_of_3_matcher =
597+
m::Eq(m::UMod(m::StateRead("x"), m::Literal(3)), m::Literal(0));
598+
auto cond_write_matcher =
599+
m::Eq(m::UMod(m::StateRead("x"), m::Literal(4)), m::Literal(0));
600+
601+
EXPECT_THAT(asserts,
602+
Contains(m::Assert(
603+
_, m::Or(m::Or(x_even_matcher, x_multiple_of_3_matcher),
604+
m::Not(cond_write_matcher)))));
605+
606+
// 2. Verify the generated default next-state feedback loop
607+
EXPECT_THAT(
608+
proc->next_values(y_se),
609+
Contains(m::Next(m::StateRead("y"), m::StateRead("y"),
610+
m::And(x_even_matcher, m::Not(cond_write_matcher)))));
611+
}
612+
552613
INSTANTIATE_TEST_SUITE_P(ProcStateLegalizationPassTestSuite,
553614
ProcStateLegalizationPassTest,
554615
testing::Values(false, true));

0 commit comments

Comments
 (0)