Skip to content

Commit 520f35d

Browse files
NL02copybara-github
authored andcommitted
[Explicit State Access] Update xls/jit code to use GetStateReadsByStateElement and GetStateReads
PiperOrigin-RevId: 896008254
1 parent 801871b commit 520f35d

27 files changed

Lines changed: 513 additions & 281 deletions

xls/ir/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -931,6 +931,7 @@ cc_test(
931931
"@com_google_absl//absl/base",
932932
"@com_google_absl//absl/status",
933933
"@com_google_absl//absl/strings",
934+
"@com_google_absl//absl/types:span",
934935
"@googletest//:gtest",
935936
],
936937
)

xls/ir/ir_parser.cc

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1895,9 +1895,11 @@ absl::StatusOr<Parser::BodyResult> Parser::ParseBody(
18951895
Proc * source_proc,
18961896
ParseProc(package, /*outer_attributes=*/{}, &source));
18971897
for (StateElement* element : source_proc->StateElements()) {
1898-
name_to_value->emplace(
1899-
element->name(),
1900-
bb->SourceNode(source_proc->GetStateReadByStateElement(element)));
1898+
absl::Span<StateRead* const> reads =
1899+
source_proc->GetStateReadsByStateElement(element);
1900+
XLS_RET_CHECK_EQ(reads.size(), 1);
1901+
name_to_value->emplace(element->name(),
1902+
bb->SourceNode(reads.front()));
19011903
}
19021904
} else {
19031905
return absl::InvalidArgumentError(absl::StrFormat(

xls/ir/ir_parser_test.cc

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "absl/strings/ascii.h"
2828
#include "absl/strings/str_cat.h"
2929
#include "absl/strings/substitute.h"
30+
#include "absl/types/span.h"
3031
#include "xls/common/source_location.h"
3132
#include "xls/common/status/matchers.h"
3233
#include "xls/ir/bits.h"
@@ -607,7 +608,9 @@ proc foo( x: bits[32], y: (), z: bits[32], init={42, (), 123}) {
607608
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, package->GetProc("foo"));
608609
EXPECT_EQ(proc->GetStateElementCount(), 3);
609610
XLS_ASSERT_OK_AND_ASSIGN(StateElement * x, proc->GetStateElementByName("x"));
610-
EXPECT_THAT(proc->GetStateReadByStateElement(x)->predicate(), std::nullopt);
611+
absl::Span<StateRead* const> reads = proc->GetStateReadsByStateElement(x);
612+
ASSERT_EQ(reads.size(), 1);
613+
EXPECT_THAT(reads.front()->predicate(), std::nullopt);
611614
}
612615

613616
TEST(IrParserTest, ProcWithPredicatedStateRead) {
@@ -626,17 +629,22 @@ proc foo( x: bits[32], y: bits[1], z: bits[32], init={42, 1, 123}) {
626629
EXPECT_EQ(proc->GetStateElementCount(), 3);
627630

628631
XLS_ASSERT_OK_AND_ASSIGN(StateElement * x, proc->GetStateElementByName("x"));
629-
std::optional<Node*> x_predicate =
630-
proc->GetStateReadByStateElement(x)->predicate();
632+
absl::Span<StateRead* const> reads_x = proc->GetStateReadsByStateElement(x);
633+
ASSERT_EQ(reads_x.size(), 1);
634+
std::optional<Node*> x_predicate = reads_x.front()->predicate();
631635
ASSERT_TRUE(x_predicate.has_value());
632636
ASSERT_EQ((*x_predicate)->op(), Op::kStateRead);
633637
EXPECT_EQ((*x_predicate)->As<StateRead>()->state_element()->name(), "y");
634638

635639
XLS_ASSERT_OK_AND_ASSIGN(StateElement * y, proc->GetStateElementByName("y"));
636-
ASSERT_FALSE(proc->GetStateReadByStateElement(y)->predicate().has_value());
640+
absl::Span<StateRead* const> reads_y = proc->GetStateReadsByStateElement(y);
641+
ASSERT_EQ(reads_y.size(), 1);
642+
ASSERT_FALSE(reads_y.front()->predicate().has_value());
637643

638644
XLS_ASSERT_OK_AND_ASSIGN(StateElement * z, proc->GetStateElementByName("z"));
639-
ASSERT_FALSE(proc->GetStateReadByStateElement(z)->predicate().has_value());
645+
absl::Span<StateRead* const> reads_z = proc->GetStateReadsByStateElement(z);
646+
ASSERT_EQ(reads_z.size(), 1);
647+
ASSERT_FALSE(reads_z.front()->predicate().has_value());
640648
}
641649

642650
TEST(IrParserTest, ParseSendReceiveChannel) {

xls/ir/node_util_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -381,7 +381,7 @@ TEST_F(NodeUtilTest, ChannelNodes) {
381381

382382
EXPECT_THAT(GetChannelUsedByNode(rcv.node()), IsOkAndHolds(ch0));
383383
EXPECT_THAT(GetChannelUsedByNode(send.node()), IsOkAndHolds(ch1));
384-
EXPECT_THAT(GetChannelUsedByNode(proc->GetStateRead(0)),
384+
EXPECT_THAT(GetChannelUsedByNode(proc->GetStateReads(0).front()),
385385
StatusIs(absl::StatusCode::kNotFound,
386386
HasSubstr("No channel associated with node")));
387387
}
@@ -435,7 +435,7 @@ TEST_F(NodeUtilTest, ReplaceTupleIndicesWorksWithToken) {
435435
// works, we'd need to make an after_all and add the receive's output token to
436436
// it after calling ReplaceTupleElementsWith().
437437
XLS_EXPECT_OK(ReplaceTupleElementsWith(
438-
receive_node, {{0, proc->GetStateRead(0)}, {1, lit0}}));
438+
receive_node, {{0, proc->GetStateReads(0).front()}, {1, lit0}}));
439439

440440
ExpectIr(proc->DumpIr(), TestName());
441441
}

xls/ir/proc.cc

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -214,13 +214,15 @@ absl::Status Proc::RemoveStateElement(int64_t index) {
214214
StateElement* old_state_element = GetStateElement(index);
215215
auto old_state_read_it = state_reads_.find(old_state_element);
216216
XLS_RET_CHECK(old_state_read_it != state_reads_.end());
217-
if (!old_state_read_it->second->users().empty()) {
218-
return absl::InvalidArgumentError(absl::StrFormat(
219-
"Cannot remove state element %d of proc %s, existing "
220-
"state read %s has uses",
221-
index, name(), old_state_read_it->second->GetNameView()));
217+
for (StateRead* read : old_state_read_it->second) {
218+
if (!read->users().empty()) {
219+
return absl::InvalidArgumentError(
220+
absl::StrFormat("Cannot remove state element %d of proc %s, existing "
221+
"state read %s has uses",
222+
index, name(), read->GetNameView()));
223+
}
224+
XLS_RETURN_IF_ERROR(RemoveNode(read));
222225
}
223-
XLS_RETURN_IF_ERROR(RemoveNode(old_state_read_it->second));
224226
// TODO(allight): This should ideally not need to be done manually.
225227
state_reads_.erase(old_state_read_it);
226228

@@ -232,11 +234,14 @@ absl::Status Proc::RemoveStateElement(int64_t index) {
232234
absl::Status Proc::RemoveAllStateElements() {
233235
// TODO(allight): This relies on side tables being valid. For now just let it
234236
// go.
235-
for (const auto& [elem, read] : state_reads_) {
236-
if (read != nullptr) {
237-
XLS_RETURN_IF_ERROR(RemoveNode(read))
238-
<< "Cannot remove " << elem->ToString() << " of proc " << name()
239-
<< " because read '" << read->ToString() << "' could not be removed.";
237+
for (const auto& [elem, reads] : state_reads_) {
238+
for (StateRead* read : reads) {
239+
if (read != nullptr) {
240+
XLS_RETURN_IF_ERROR(RemoveNode(read))
241+
<< "Cannot remove " << elem->ToString() << " of proc " << name()
242+
<< " because read '" << read->ToString()
243+
<< "' could not be removed.";
244+
}
240245
}
241246
XLS_RETURN_IF_ERROR(state_name_uniquer_.ReleaseIdentifier(elem->name()))
242247
<< "Cannot release name of " << elem->ToString();
@@ -278,7 +283,7 @@ absl::StatusOr<StateRead*> Proc::InsertStateElement(
278283
MakeNodeWithName<StateRead>(
279284
loc, state_element, read_predicate,
280285
/*label=*/std::nullopt, state_element->name()));
281-
state_reads_[state_element] = state_read;
286+
state_reads_[state_element].push_back(state_read);
282287

283288
if (next_state.has_value()) {
284289
if (!ValueConformsToType(init_value, next_state.value()->GetType())) {
@@ -351,14 +356,13 @@ absl::StatusOr<Proc*> Proc::Clone(
351356
return mapping.at(orig);
352357
};
353358
for (StateElement* state_element : StateElements()) {
354-
StateRead* state_read = state_reads_.at(state_element);
355-
XLS_ASSIGN_OR_RETURN(
356-
StateRead * cloned_state_read,
357-
cloned_proc->AppendStateElement(
358-
remap_name(state_name_remapping, state_element->name()),
359-
state_element->initial_value(), state_read->predicate(),
360-
/*next_state=*/std::nullopt));
361-
original_to_clone[state_read] = cloned_state_read;
359+
XLS_RETURN_IF_ERROR(
360+
cloned_proc
361+
->InsertUnreadStateElement(
362+
cloned_proc->GetStateElementCount(),
363+
remap_name(state_name_remapping, state_element->name()),
364+
state_element->initial_value())
365+
.status());
362366
}
363367
if (is_new_style_proc()) {
364368
absl::flat_hash_map<ChannelInterface*, ChannelInterface*> channel_map;
@@ -445,7 +449,23 @@ absl::StatusOr<Proc*> Proc::Clone(
445449

446450
switch (node->op()) {
447451
case Op::kStateRead: {
448-
continue;
452+
StateRead* src = node->As<StateRead>();
453+
StateElement* src_elem = src->state_element();
454+
XLS_ASSIGN_OR_RETURN(int64_t idx, GetStateElementIndex(src_elem));
455+
StateElement* cloned_elem = cloned_proc->GetStateElement(idx);
456+
457+
std::optional<Node*> cloned_predicate;
458+
if (src->predicate().has_value()) {
459+
cloned_predicate = original_to_clone.at(src->predicate().value());
460+
}
461+
462+
XLS_ASSIGN_OR_RETURN(StateRead * cloned_state_read,
463+
cloned_proc->MakeNodeWithName<StateRead>(
464+
src->loc(), cloned_elem, cloned_predicate,
465+
/*label=*/std::nullopt, cloned_elem->name()));
466+
cloned_proc->state_reads_[cloned_elem].push_back(cloned_state_read);
467+
original_to_clone[node] = cloned_state_read;
468+
break;
449469
}
450470
case Op::kReceive: {
451471
Receive* src = node->As<Receive>();
@@ -1000,10 +1020,8 @@ absl::Status Proc::InternalRebuildSideTables() {
10001020
state_reads_.clear();
10011021
for (Node* n : nodes()) {
10021022
if (n->Is<StateRead>()) {
1003-
XLS_RET_CHECK(!state_reads_.contains(n->As<StateRead>()->state_element()))
1004-
<< "Duplicate state element read: "
1005-
<< n->As<StateRead>()->state_element();
1006-
state_reads_[n->As<StateRead>()->state_element()] = n->As<StateRead>();
1023+
state_reads_[n->As<StateRead>()->state_element()].push_back(
1024+
n->As<StateRead>());
10071025
} else if (n->Is<Next>()) {
10081026
next_values_.push_back(n->As<Next>());
10091027
next_values_by_state_element_[n->As<Next>()->state_element()].insert(

xls/ir/proc.h

Lines changed: 16 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -95,10 +95,23 @@ class Proc : public FunctionBase {
9595
return state_elements_.contains(name);
9696
}
9797

98+
// Remove legacy getters after all downstream passes migrate logic.
9899
StateRead* GetStateRead(int64_t index) const {
99-
return state_reads_.at(GetStateElement(index));
100+
return GetStateReads(index).front();
100101
}
102+
101103
StateRead* GetStateReadByStateElement(StateElement* state_element) const {
104+
return GetStateReadsByStateElement(state_element).front();
105+
}
106+
107+
// Get state reads for a state element at the given index.
108+
absl::Span<StateRead* const> GetStateReads(int64_t index) const {
109+
return state_reads_.at(GetStateElement(index));
110+
}
111+
112+
// Get state reads for a state element.
113+
absl::Span<StateRead* const> GetStateReadsByStateElement(
114+
StateElement* state_element) const {
102115
return state_reads_.at(state_element);
103116
}
104117

@@ -403,8 +416,8 @@ class Proc : public FunctionBase {
403416
absl::flat_hash_map<std::string, std::unique_ptr<StateElement>>
404417
state_elements_;
405418

406-
// Map of the unique StateRead node for each state element.
407-
absl::flat_hash_map<StateElement*, StateRead*> state_reads_;
419+
// Map of StateRead nodes for each state element.
420+
absl::flat_hash_map<StateElement*, std::vector<StateRead*>> state_reads_;
408421

409422
// Vector of state element pointers. Kept in sync with the state_elements_
410423
// map. Enables easy, stable iteration over state elements. With this vector,

xls/ir/proc_test.cc

Lines changed: 54 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,43 @@ TEST_F(ProcTest, StatelessProc) {
205205
EXPECT_EQ(proc->DumpIr(), "proc p() {\n}\n");
206206
}
207207

208+
TEST_F(ProcTest, MultipleStateReads) {
209+
auto p = CreatePackage();
210+
ProcBuilder pb("p", p.get());
211+
BValue tkn = pb.StateElement("tkn", Value::Token());
212+
BValue state = pb.StateElement("x", Value(UBits(42, 32)));
213+
XLS_ASSERT_OK_AND_ASSIGN(Proc * proc, pb.Build({tkn, state}));
214+
215+
StateElement* state_elem = proc->GetStateElement(1);
216+
217+
EXPECT_EQ(proc->GetStateReads(1).size(), 1);
218+
StateRead* read1 = proc->GetStateReads(1).front();
219+
220+
// Second Read
221+
XLS_ASSERT_OK_AND_ASSIGN(
222+
StateRead * read2,
223+
proc->MakeNodeWithName<StateRead>(SourceInfo(), state_elem,
224+
/*predicate=*/std::nullopt,
225+
/*label=*/std::nullopt, "x_read2"));
226+
XLS_ASSERT_OK(proc->RebuildSideTables());
227+
228+
EXPECT_EQ(proc->GetStateReads(1).size(), 2);
229+
EXPECT_THAT(proc->GetStateReads(1), ElementsAre(read1, read2));
230+
231+
EXPECT_EQ(proc->GetStateReadsByStateElement(state_elem).size(), 2);
232+
EXPECT_THAT(proc->GetStateReadsByStateElement(state_elem),
233+
ElementsAre(read1, read2));
234+
235+
// Remove the second read.
236+
std::string read2_name = read2->GetName();
237+
XLS_ASSERT_OK(proc->RemoveNode(read2));
238+
XLS_ASSERT_OK(proc->RebuildSideTables());
239+
240+
// Now we should have 1 read again.
241+
EXPECT_EQ(proc->GetStateReads(1).size(), 1);
242+
EXPECT_EQ(proc->GetStateReads(1).front(), read1);
243+
}
244+
208245
TEST_F(ProcTest, RemoveStateThatStillHasUse) {
209246
// Don't call CreatePackage which creates a VerifiedPackage because we
210247
// intentionally create a malformed proc.
@@ -254,10 +291,10 @@ TEST_F(ProcTest, Clone) {
254291
EXPECT_EQ(clone->DumpIr(),
255292
R"(proc cloned(tkn: token, state: bits[32], init={token, 42}) {
256293
tkn: token = state_read(state_element=tkn, id=12)
257-
literal.14: bits[32] = literal(value=1, id=14)
258-
state: bits[32] = state_read(state_element=state, id=13)
294+
literal.13: bits[32] = literal(value=1, id=13)
295+
state: bits[32] = state_read(state_element=state, id=14)
259296
receive_3: (token, bits[32]) = receive(tkn, channel=cloned_chan, id=15)
260-
add.16: bits[32] = add(literal.14, state, id=16)
297+
add.16: bits[32] = add(literal.13, state, id=16)
261298
tuple_index.17: bits[32] = tuple_index(receive_3, index=1, id=17)
262299
tuple_index.18: token = tuple_index(receive_3, index=0, id=18)
263300
add.19: bits[32] = add(add.16, tuple_index.17, id=19)
@@ -304,10 +341,10 @@ proc cloned<input_chan: bits[32] in, chan: bits[32] out>(tkn: token, state: bits
304341
chan_interface input_chan(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
305342
chan_interface chan(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
306343
tkn: token = state_read(state_element=tkn, id=1)
307-
literal.3: bits[32] = literal(value=1, id=3)
308-
state: bits[32] = state_read(state_element=state, id=2)
344+
literal.2: bits[32] = literal(value=1, id=2)
345+
state: bits[32] = state_read(state_element=state, id=3)
309346
receive_3: (token, bits[32]) = receive(tkn, channel=input_chan, id=4)
310-
add.5: bits[32] = add(literal.3, state, id=5)
347+
add.5: bits[32] = add(literal.2, state, id=5)
311348
tuple_index.6: bits[32] = tuple_index(receive_3, index=1, id=6)
312349
tuple_index.7: token = tuple_index(receive_3, index=0, id=7)
313350
add.8: bits[32] = add(add.5, tuple_index.6, id=8)
@@ -355,15 +392,15 @@ TEST_F(ProcTest, CloneNewStyle) {
355392
chan baz(bits[32], id=0, kind=streaming, ops=send_receive, flow_control=ready_valid, strictness=proven_mutually_exclusive)
356393
chan_interface baz(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=none, flop_kind=none)
357394
chan_interface baz(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=none, flop_kind=none)
358-
tkn: token = literal(value=token, id=14)
359-
receive_3: (token, bits[32]) = receive(tkn, channel=foo, id=15)
360-
tuple_index.16: token = tuple_index(receive_3, index=0, id=16)
361-
receive_6: (token, bits[32]) = receive(tuple_index.16, channel=baz, id=17)
362-
tuple_index.18: token = tuple_index(receive_6, index=0, id=18)
363-
state: bits[32] = state_read(state_element=state, id=13)
395+
tkn: token = literal(value=token, id=13)
396+
receive_3: (token, bits[32]) = receive(tkn, channel=foo, id=14)
397+
tuple_index.15: token = tuple_index(receive_3, index=0, id=15)
398+
receive_6: (token, bits[32]) = receive(tuple_index.15, channel=baz, id=16)
399+
tuple_index.17: token = tuple_index(receive_6, index=0, id=17)
400+
state: bits[32] = state_read(state_element=state, id=18)
364401
tuple_index.19: bits[32] = tuple_index(receive_3, index=1, id=19)
365402
tuple_index.20: bits[32] = tuple_index(receive_6, index=1, id=20)
366-
send_9: token = send(tuple_index.18, state, channel=bar, id=21)
403+
send_9: token = send(tuple_index.17, state, channel=bar, id=21)
367404
add.22: bits[32] = add(tuple_index.19, tuple_index.20, id=22)
368405
send_10: token = send(send_9, state, channel=baz, id=23)
369406
next_value.24: () = next_value(param=state, value=add.22, id=24)
@@ -556,7 +593,8 @@ TEST_F(ScheduledProcTest, StageAddAndClear) {
556593
proc->ClearStages();
557594
EXPECT_TRUE(proc->stages().empty());
558595
// Re-stage the state element to satisfy the verifier.
559-
XLS_ASSERT_OK(proc->AddNodeToStage(0, proc->GetStateRead(0)).status());
596+
XLS_ASSERT_OK(
597+
proc->AddNodeToStage(0, proc->GetStateReads(0).front()).status());
560598
}
561599

562600
TEST_F(ScheduledProcTest, AddEmptyStages) {
@@ -596,7 +634,8 @@ TEST_F(ScheduledProcTest, GetStageIndex) {
596634
EXPECT_THAT(proc->GetStageIndex(x), IsOkAndHolds(1));
597635
EXPECT_THAT(proc->GetStageIndex(y), IsOkAndHolds(2));
598636
EXPECT_THAT(proc->GetStageIndex(add), StatusIs(absl::StatusCode::kNotFound));
599-
EXPECT_THAT(proc->GetStageIndex(proc->GetStateRead(0)), IsOkAndHolds(0));
637+
EXPECT_THAT(proc->GetStageIndex(proc->GetStateReads(0).front()),
638+
IsOkAndHolds(0));
600639

601640
// The verifier requires that every node be in a stage before we finish.
602641
ASSERT_THAT(proc->AddNodeToStage(2, add), IsOkAndHolds(true));

xls/ir/proc_testutils.cc

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -244,11 +244,14 @@ absl::StatusOr<std::vector<BValue>> GetStateValuesBeforeActivation(
244244
absl::flat_hash_map<NodeActivation, BValue>& values) {
245245
std::vector<BValue> states;
246246
for (StateElement* state_element : p->StateElements()) {
247-
StateRead* state_read = p->GetStateReadByStateElement(state_element);
247+
absl::Span<StateRead* const> reads =
248+
p->GetStateReadsByStateElement(state_element);
249+
XLS_RET_CHECK(!reads.empty()) << "No reads for " << state_element;
250+
251+
BValue state_val;
248252
if (activation == 0) {
249-
values[{state_read, 0}] =
250-
fb.Literal(state_element->initial_value(), SourceInfo(),
251-
absl::StrFormat("%s_initial_value", p->name()));
253+
state_val = fb.Literal(state_element->initial_value(), SourceInfo(),
254+
absl::StrFormat("%s_initial_value", p->name()));
252255
} else {
253256
std::vector<BValue> cases;
254257
std::vector<BValue> selectors;
@@ -261,23 +264,26 @@ absl::StatusOr<std::vector<BValue>> GetStateValuesBeforeActivation(
261264
}
262265
if (selectors.empty()) {
263266
XLS_RET_CHECK_EQ(cases.size(), 1) << "no cases for " << state_element;
264-
values[{state_read, activation}] = cases.front();
267+
state_val = cases.front();
265268
} else if (cases.front().GetType()->IsBits() &&
266269
cases.front().GetType()->GetFlatBitCount() == 0) {
267270
// Special case to avoid creating non-trivial uses of zero-len bit
268271
// vectors.
269-
values[{state_read, activation}] = fb.Literal(UBits(0, 0));
272+
state_val = fb.Literal(UBits(0, 0));
270273
} else {
271274
XLS_RET_CHECK_EQ(cases.size(), selectors.size());
272275
// materialize the next values into a select.
273276
// Need to reverse to keep the LSB is case 0 etc.
274277
absl::c_reverse(selectors);
275-
values[{state_read, activation}] = fb.PrioritySelect(
278+
state_val = fb.PrioritySelect(
276279
fb.Concat(selectors), cases,
277-
/*default_value=*/values[{state_read, activation - 1}]);
280+
/*default_value=*/values[{reads.front(), activation - 1}]);
278281
}
279282
}
280-
states.push_back(values[{state_read, activation}]);
283+
for (StateRead* read : reads) {
284+
values[{read, activation}] = state_val;
285+
}
286+
states.push_back(state_val);
281287
}
282288
return states;
283289
}

0 commit comments

Comments
 (0)