Skip to content

Commit 64e0a03

Browse files
ericastorcopybara-github
authored andcommitted
[opt] Fix array_untuple bug by properly excluding unhandled node operands
When an unhandled node (e.g. `invoke`) returned an array of tuples, the external-group-finding analysis in `ArrayUntuplePass` incorrectly skipped checking the node's operands due to an early continue. This caused unsupported array-of-tuples arguments passed to that node to remain eligible for untupling, resulting in broken state mutations and incorrect correctness outcomes. This change removes the early continue, ensuring that operands of all unhandled nodes are correctly identified and excluded from untupling, preserving package semantics. #geminiassisted PiperOrigin-RevId: 914322574
1 parent 3d0955c commit 64e0a03

3 files changed

Lines changed: 73 additions & 1 deletion

File tree

xls/passes/BUILD

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4332,17 +4332,23 @@ cc_test(
43324332
"//xls/common/status:matchers",
43334333
"//xls/fuzzer/ir_fuzzer:ir_fuzz_domain",
43344334
"//xls/fuzzer/ir_fuzzer:ir_fuzz_test_library",
4335+
"//xls/interpreter:channel_queue",
4336+
"//xls/interpreter:interpreter_proc_runtime",
4337+
"//xls/interpreter:serial_proc_runtime",
43354338
"//xls/ir",
43364339
"//xls/ir:bits",
43374340
"//xls/ir:channel_ops",
4341+
"//xls/ir:clone_package",
43384342
"//xls/ir:function_builder",
43394343
"//xls/ir:ir_matcher",
43404344
"//xls/ir:ir_test_base",
43414345
"//xls/ir:value",
43424346
"//xls/ir:value_builder",
4347+
"//xls/ir:value_utils",
43434348
"//xls/solvers:z3_ir_equivalence_testutils",
43444349
"@com_google_absl//absl/status:status_matchers",
43454350
"@com_google_absl//absl/status:statusor",
4351+
"@com_google_absl//absl/types:span",
43464352
"@googletest//:gtest",
43474353
],
43484354
)

xls/passes/array_untuple_pass.cc

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,6 @@ absl::StatusOr<absl::flat_hash_set<Node*>> FindExternalGroups(
190190
VLOG(2) << "Unable to untuple " << n << " (in group: " << groups.Find(n)
191191
<< ")";
192192
excluded.insert(groups.Find(n));
193-
continue;
194193
}
195194
// We need to exclude this if the result is an array but we don't need to
196195
// exclude its operands.

xls/passes/array_untuple_pass_test.cc

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "xls/passes/array_untuple_pass.h"
1616

1717
#include <cstdint>
18+
#include <optional>
1819
#include <string_view>
1920
#include <utility>
2021

@@ -23,18 +24,24 @@
2324
#include "xls/common/fuzzing/fuzztest.h"
2425
#include "absl/status/status_matchers.h"
2526
#include "absl/status/statusor.h"
27+
#include "absl/types/span.h"
2628
#include "xls/common/status/matchers.h"
2729
#include "xls/fuzzer/ir_fuzzer/ir_fuzz_domain.h"
2830
#include "xls/fuzzer/ir_fuzzer/ir_fuzz_test_library.h"
31+
#include "xls/interpreter/channel_queue.h"
32+
#include "xls/interpreter/interpreter_proc_runtime.h"
33+
#include "xls/interpreter/serial_proc_runtime.h"
2934
#include "xls/ir/bits.h"
3035
#include "xls/ir/channel_ops.h"
36+
#include "xls/ir/clone_package.h"
3137
#include "xls/ir/function.h"
3238
#include "xls/ir/function_builder.h"
3339
#include "xls/ir/ir_matcher.h"
3440
#include "xls/ir/ir_test_base.h"
3541
#include "xls/ir/package.h"
3642
#include "xls/ir/value.h"
3743
#include "xls/ir/value_builder.h"
44+
#include "xls/ir/value_utils.h"
3845
#include "xls/passes/dataflow_simplification_pass.h"
3946
#include "xls/passes/dce_pass.h"
4047
#include "xls/passes/optimization_pass.h"
@@ -45,6 +52,7 @@ namespace m = ::xls::op_matchers;
4552

4653
namespace xls {
4754
namespace {
55+
4856
using ::absl_testing::IsOkAndHolds;
4957
using ::testing::_;
5058
using ::testing::IsSupersetOf;
@@ -458,6 +466,65 @@ TEST_F(ArrayUntuplePassTest, ProcStateArrayImplicitNext) {
458466
m::StateElement(_, m::Type("bits[3][4]"))}));
459467
}
460468

469+
TEST_F(ArrayUntuplePassTest, ProcStateArrayWithInvoke) {
470+
auto p = CreatePackage();
471+
Type* u1 = p->GetBitsType(1);
472+
Type* u1_pair = p->GetTupleType({u1, u1});
473+
Type* st_type = p->GetArrayType(1, u1_pair);
474+
475+
XLS_ASSERT_OK_AND_ASSIGN(
476+
auto ch_out,
477+
p->CreateStreamingChannel("ch_out", ChannelOps::kSendOnly, st_type));
478+
479+
FunctionBuilder fb("for_body", p.get());
480+
XLS_ASSERT_OK_AND_ASSIGN(Function * f,
481+
fb.BuildWithReturnValue(fb.Param("st", st_type)));
482+
483+
ProcBuilder pb("inner", p.get());
484+
BValue acc = pb.StateElement("acc", ZeroOfType(st_type));
485+
BValue tok = pb.Literal(Value::Token());
486+
BValue red = pb.Invoke(absl::MakeConstSpan({acc}), f);
487+
BValue oa = pb.Literal(ValueBuilder::ArrayB({
488+
ValueBuilder::Tuple({Value(UBits(1, 1)), Value(UBits(1, 1))}),
489+
}));
490+
pb.Send(ch_out, tok, red);
491+
pb.Next(acc, oa);
492+
493+
XLS_ASSERT_OK(pb.Build().status());
494+
495+
// 1. Evaluate BEFORE the pass runs on a clone of the package.
496+
XLS_ASSERT_OK_AND_ASSIGN(std::unique_ptr<Package> original_pkg,
497+
ClonePackage(p.get()));
498+
XLS_ASSERT_OK_AND_ASSIGN(
499+
std::unique_ptr<SerialProcRuntime> eval_before,
500+
CreateInterpreterSerialProcRuntime(original_pkg.get()));
501+
XLS_ASSERT_OK(eval_before->Tick());
502+
XLS_ASSERT_OK(eval_before->Tick());
503+
ChannelQueue& queue_before = eval_before->queue_manager().GetQueue(
504+
original_pkg->GetChannel("ch_out").value());
505+
ASSERT_EQ(queue_before.GetSize(), 2);
506+
std::optional<Value> val1_before = queue_before.Read();
507+
std::optional<Value> val2_before = queue_before.Read();
508+
509+
// 2. Run the pass (mutates package `p` in-place).
510+
ScopedRecordIr sri(p.get());
511+
EXPECT_THAT(RunPass(p.get()), IsOkAndHolds(false));
512+
513+
// 3. Evaluate AFTER the pass runs.
514+
XLS_ASSERT_OK_AND_ASSIGN(std::unique_ptr<SerialProcRuntime> eval_after,
515+
CreateInterpreterSerialProcRuntime(p.get()));
516+
XLS_ASSERT_OK(eval_after->Tick());
517+
XLS_ASSERT_OK(eval_after->Tick());
518+
ChannelQueue& queue_after = eval_after->queue_manager().GetQueue(ch_out);
519+
ASSERT_EQ(queue_after.GetSize(), 2);
520+
std::optional<Value> val1_after = queue_after.Read();
521+
std::optional<Value> val2_after = queue_after.Read();
522+
523+
// Confirm that the output values agree before and after the pass runs.
524+
EXPECT_EQ(val1_before, val1_after);
525+
EXPECT_EQ(val2_before, val2_after);
526+
}
527+
461528
void IrFuzzArrayUntuple(FuzzPackageWithArgs fuzz_package_with_args) {
462529
ArrayUntuplePass pass;
463530
OptimizationPassChangesOutputs(std::move(fuzz_package_with_args), pass);

0 commit comments

Comments
 (0)