2020#include < variant>
2121#include < vector>
2222
23+ #include " absl/algorithm/container.h"
2324#include " absl/container/btree_set.h"
2425#include " absl/container/flat_hash_set.h"
2526#include " absl/log/check.h"
@@ -53,9 +54,17 @@ namespace {
5354absl::StatusOr<bool > LegalizeStateReadPredicate (
5455 Proc* proc, StateElement* state_element,
5556 const SchedulingPassOptions& options) {
56- StateRead* state_read = proc->GetStateReadByStateElement (state_element);
57+ absl::Span<StateRead* const > reads =
58+ proc->GetStateReadsByStateElement (state_element);
5759 const absl::btree_set<Next*, Node::NodeIdLessThan>& next_values =
5860 proc->next_values (state_element);
61+
62+ // Only legalize predicate if there is exactly one state_read.
63+ if (reads.size () != 1 ) {
64+ return false ;
65+ }
66+
67+ StateRead* state_read = reads.front ();
5968 if (!state_read->predicate ().has_value () || next_values.empty ()) {
6069 // Already unconditional, or no explicit `next_value`s; nothing to do.
6170 return false ;
@@ -66,7 +75,7 @@ absl::StatusOr<bool> LegalizeStateReadPredicate(
6675 predicates.reserve (1 + next_values.size ());
6776 predicates_set.reserve (next_values.size ());
6877 for (Next* next : next_values) {
69- if (next-> state_read () == next->value ()) {
78+ if (state_read == next->value ()) {
7079 // This is a no-op next_value; we will narrow it to the case where the
7180 // state read is active instead.
7281 continue ;
@@ -215,8 +224,12 @@ absl::StatusOr<bool> AddMutualExclusionAsserts(
215224absl::StatusOr<bool > AddWriteWithoutReadAsserts (
216225 Proc* proc, StateElement* state_element,
217226 const SchedulingPassOptions& options) {
218- StateRead* state_read = proc->GetStateReadByStateElement (state_element);
219- if (!state_read->predicate ().has_value ()) {
227+ absl::Span<StateRead* const > reads =
228+ proc->GetStateReadsByStateElement (state_element);
229+
230+ if (absl::c_any_of (reads, [](StateRead* state_read) {
231+ return !state_read->predicate ().has_value ();
232+ })) {
220233 return false ;
221234 }
222235
@@ -226,7 +239,23 @@ absl::StatusOr<bool> AddWriteWithoutReadAsserts(
226239 return false ;
227240 }
228241
229- std::vector<Node*> predicate_list;
242+ std::vector<Node*> read_predicates;
243+ read_predicates.reserve (reads.size ());
244+ for (StateRead* read : reads) {
245+ read_predicates.push_back (*read->predicate ());
246+ }
247+
248+ Node* any_read_active;
249+ if (reads.size () == 1 ) {
250+ any_read_active = *reads[0 ]->predicate ();
251+ } else {
252+ XLS_ASSIGN_OR_RETURN (
253+ any_read_active,
254+ proc->MakeNodeWithName <NaryOp>(
255+ SourceInfo (), read_predicates, Op::kOr ,
256+ absl::StrCat (" __" , state_element->name (), " __any_read_active" )));
257+ }
258+
230259 for (Next* next : next_values) {
231260 XLS_RET_CHECK (next->predicate ().has_value ());
232261 XLS_ASSIGN_OR_RETURN (
@@ -239,8 +268,7 @@ absl::StatusOr<bool> AddWriteWithoutReadAsserts(
239268 Node * no_write_without_read,
240269 proc->MakeNodeWithName <NaryOp>(
241270 SourceInfo (),
242- absl::MakeConstSpan ({*state_read->predicate (), next_not_triggered}),
243- Op::kOr ,
271+ absl::MakeConstSpan ({any_read_active, next_not_triggered}), Op::kOr ,
244272 absl::StrCat (" __" , state_element->name (), " __no_next_" , next->id (),
245273 " _without_read" )));
246274 std::string label = absl::StrCat (" __" , state_element->name (), " __next_" ,
@@ -291,7 +319,18 @@ absl::StatusOr<bool> AddDefaultNextValue(Proc* proc,
291319 StateElement* state_element,
292320 const SchedulingPassOptions& options) {
293321 absl::btree_set<Node*, Node::NodeIdLessThan> predicates;
294- StateRead* state_read = proc->GetStateReadByStateElement (state_element);
322+ absl::Span<StateRead* const > reads =
323+ proc->GetStateReadsByStateElement (state_element);
324+ StateRead* state_read = nullptr ;
325+ for (StateRead* read : reads) {
326+ if (!read->predicate ().has_value ()) {
327+ state_read = read;
328+ break ;
329+ }
330+ }
331+ if (state_read == nullptr ) {
332+ state_read = reads.front ();
333+ }
295334 for (Next* next : proc->next_values (state_element)) {
296335 if (next->predicate ().has_value ()) {
297336 predicates.insert (*next->predicate ());
0 commit comments