@@ -930,26 +930,46 @@ absl::StatusOr<StateRead*> Proc::TransformStateElement(
930930 Proc::StateElementTransformer& transform) {
931931 StateElement* old_state_element = old_state_read->state_element ();
932932 std::string orig_name (old_state_element->name ());
933- std::string orig_read_name (old_state_read->GetNameView ());
934- XLS_ASSIGN_OR_RETURN (std::optional<Node*> read_predicate,
935- transform.TransformReadPredicate (this , old_state_read));
933+
934+ absl::Span<StateRead* const > all_old_reads =
935+ GetStateReadsByStateElement (old_state_element);
936+ absl::flat_hash_map<StateRead*, std::string> orig_read_names;
937+ for (StateRead* old_read : all_old_reads) {
938+ orig_read_names[old_read] = std::string (old_read->GetNameView ());
939+ }
940+
936941 XLS_ASSIGN_OR_RETURN (
937- StateRead * new_state_read,
938- AppendStateElement (absl::StrFormat (" TEMP_NAME__%s__" , orig_name),
939- init_value, read_predicate,
940- /* next_state=*/ std::nullopt ));
941- new_state_read->SetLoc (old_state_read->loc ());
942- if (old_state_read->state_element ()->non_synthesizable ()) {
943- new_state_read->state_element ()->SetNonSynthesizable ();
942+ StateElement * new_state_element,
943+ AppendUnreadStateElement (absl::StrFormat (" TEMP_NAME__%s__" , orig_name),
944+ init_value, old_state_read->loc ()));
945+ if (old_state_element->non_synthesizable ()) {
946+ new_state_element->SetNonSynthesizable ();
944947 }
945- StateElement* new_state_element = new_state_read->state_element ();
946948 std::string temp_name = new_state_element->name ();
947949
948- XLS_ASSIGN_OR_RETURN (
949- Node * new_state_value,
950- transform.TransformStateRead (this , new_state_read, old_state_read));
951- std::vector<std::pair<Node*, Node*>> to_replace{
952- {old_state_read, new_state_value}};
950+ absl::flat_hash_map<StateRead*, StateRead*> old_to_new_read;
951+ StateRead* return_state_read = nullptr ;
952+ std::vector<std::pair<Node*, Node*>> to_replace;
953+
954+ for (StateRead* old_read : all_old_reads) {
955+ XLS_ASSIGN_OR_RETURN (std::optional<Node*> read_predicate,
956+ transform.TransformReadPredicate (this , old_read));
957+ XLS_ASSIGN_OR_RETURN (StateRead * new_read,
958+ MakeNodeWithName<StateRead>(
959+ old_read->loc (), new_state_element, read_predicate,
960+ old_read->label (), temp_name));
961+ state_reads_[new_state_element].push_back (new_read);
962+ old_to_new_read[old_read] = new_read;
963+
964+ if (old_read == old_state_read) {
965+ return_state_read = new_read;
966+ }
967+
968+ XLS_ASSIGN_OR_RETURN (Node * new_state_value, transform.TransformStateRead (
969+ this , new_read, old_read));
970+ to_replace.push_back ({old_read, new_state_value});
971+ }
972+
953973 struct NextTransformation {
954974 Next* old_next;
955975 Node* new_value;
@@ -959,14 +979,18 @@ absl::StatusOr<StateRead*> Proc::TransformStateElement(
959979 for (Next* nxt : next_values (old_state_element)) {
960980 NextTransformation& new_next = transforms.emplace_back ();
961981 new_next.old_next = nxt;
962- XLS_ASSIGN_OR_RETURN (new_next.new_value , transform.TransformNextValue (
963- this , new_state_read, nxt));
964- XLS_RET_CHECK (new_next.new_value ->GetType () == new_state_read->GetType ())
982+ StateRead* corresponding_new_read =
983+ old_to_new_read.at (nxt->state_read ()->As <StateRead>());
984+ XLS_ASSIGN_OR_RETURN (
985+ new_next.new_value ,
986+ transform.TransformNextValue (this , corresponding_new_read, nxt));
987+ XLS_RET_CHECK (new_next.new_value ->GetType () ==
988+ corresponding_new_read->GetType ())
965989 << " New value is not compatible type. Expected: "
966- << new_state_read ->GetType () << " got " << new_next.new_value ;
990+ << corresponding_new_read ->GetType () << " got " << new_next.new_value ;
967991 XLS_ASSIGN_OR_RETURN (
968992 new_next.new_predicate ,
969- transform.TransformNextPredicate (this , new_state_read , nxt));
993+ transform.TransformNextPredicate (this , corresponding_new_read , nxt));
970994 }
971995
972996 // We've transformed all the graph elements. Start replacing them.
@@ -977,24 +1001,30 @@ absl::StatusOr<StateRead*> Proc::TransformStateElement(
9771001 auto orig_storage = state_elements_.extract (orig_name);
9781002 orig_storage.key () = to_remove_name;
9791003 old_state_element->SetName (to_remove_name);
980- old_state_read->SetName (to_remove_name);
1004+ for (StateRead* old_read : all_old_reads) {
1005+ old_read->SetName (to_remove_name);
1006+ }
9811007 CHECK (state_elements_.insert (std::move (orig_storage)).inserted );
9821008
9831009 // Take over the old state element & read names.
9841010 auto new_storage = state_elements_.extract (temp_name);
9851011 new_storage.key () = orig_name;
9861012 new_state_element->SetName (orig_name);
987- new_state_read->SetNameDirectly (orig_read_name);
1013+ for (auto & [old_read, new_read] : old_to_new_read) {
1014+ new_read->SetNameDirectly (orig_read_names.at (old_read));
1015+ }
9881016 CHECK (state_elements_.insert (std::move (new_storage)).inserted );
9891017
9901018 // Identity-ify the old next nodes and create new ones.
9911019 for (const NextTransformation& nt : transforms) {
9921020 // Make the next
1021+ StateRead* corresponding_new_read =
1022+ old_to_new_read.at (nt.old_next ->state_read ()->As <StateRead>());
9931023 XLS_ASSIGN_OR_RETURN (
9941024 Next * nxt,
995- MakeNodeWithName<Next>(nt.old_next ->loc (), new_state_read, nt. new_value ,
996- nt.new_predicate , nt.old_next -> label () ,
997- nt.old_next ->GetName ()));
1025+ MakeNodeWithName<Next>(nt.old_next ->loc (), corresponding_new_read ,
1026+ nt.new_value , nt.new_predicate ,
1027+ nt.old_next ->label (), nt. old_next -> GetName ()));
9981028 to_replace.push_back ({nt.old_next , nxt});
9991029 // Identity-ify the old next.
10001030 XLS_RETURN_IF_ERROR (nt.old_next ->ReplaceOperandNumber (
@@ -1011,7 +1041,7 @@ absl::StatusOr<StateRead*> Proc::TransformStateElement(
10111041 },
10121042 /* replace_implicit_uses=*/ false ));
10131043 }
1014- return new_state_read ;
1044+ return return_state_read ;
10151045}
10161046
10171047absl::Status Proc::InternalRebuildSideTables () {
0 commit comments