1515#include " xls/passes/array_untuple_pass.h"
1616
1717#include < cstdint>
18+ #include < optional>
1819#include < string_view>
1920#include < utility>
2021
2324#include " xls/common/fuzzing/fuzztest.h"
2425#include " absl/status/status_matchers.h"
2526#include " absl/status/statusor.h"
27+ #include " absl/types/span.h"
2628#include " xls/common/status/matchers.h"
2729#include " xls/fuzzer/ir_fuzzer/ir_fuzz_domain.h"
2830#include " xls/fuzzer/ir_fuzzer/ir_fuzz_test_library.h"
31+ #include " xls/interpreter/channel_queue.h"
32+ #include " xls/interpreter/interpreter_proc_runtime.h"
33+ #include " xls/interpreter/serial_proc_runtime.h"
2934#include " xls/ir/bits.h"
3035#include " xls/ir/channel_ops.h"
36+ #include " xls/ir/clone_package.h"
3137#include " xls/ir/function.h"
3238#include " xls/ir/function_builder.h"
3339#include " xls/ir/ir_matcher.h"
3440#include " xls/ir/ir_test_base.h"
3541#include " xls/ir/package.h"
3642#include " xls/ir/value.h"
3743#include " xls/ir/value_builder.h"
44+ #include " xls/ir/value_utils.h"
3845#include " xls/passes/dataflow_simplification_pass.h"
3946#include " xls/passes/dce_pass.h"
4047#include " xls/passes/optimization_pass.h"
@@ -45,6 +52,7 @@ namespace m = ::xls::op_matchers;
4552
4653namespace xls {
4754namespace {
55+
4856using ::absl_testing::IsOkAndHolds;
4957using ::testing::_;
5058using ::testing::IsSupersetOf;
@@ -458,6 +466,65 @@ TEST_F(ArrayUntuplePassTest, ProcStateArrayImplicitNext) {
458466 m::StateElement (_, m::Type (" bits[3][4]" ))}));
459467}
460468
469+ TEST_F (ArrayUntuplePassTest, ProcStateArrayWithInvoke) {
470+ auto p = CreatePackage ();
471+ Type* u1 = p->GetBitsType (1 );
472+ Type* u1_pair = p->GetTupleType ({u1, u1});
473+ Type* st_type = p->GetArrayType (1 , u1_pair);
474+
475+ XLS_ASSERT_OK_AND_ASSIGN (
476+ auto ch_out,
477+ p->CreateStreamingChannel (" ch_out" , ChannelOps::kSendOnly , st_type));
478+
479+ FunctionBuilder fb (" for_body" , p.get ());
480+ XLS_ASSERT_OK_AND_ASSIGN (Function * f,
481+ fb.BuildWithReturnValue (fb.Param (" st" , st_type)));
482+
483+ ProcBuilder pb (" inner" , p.get ());
484+ BValue acc = pb.StateElement (" acc" , ZeroOfType (st_type));
485+ BValue tok = pb.Literal (Value::Token ());
486+ BValue red = pb.Invoke (absl::MakeConstSpan ({acc}), f);
487+ BValue oa = pb.Literal (ValueBuilder::ArrayB ({
488+ ValueBuilder::Tuple ({Value (UBits (1 , 1 )), Value (UBits (1 , 1 ))}),
489+ }));
490+ pb.Send (ch_out, tok, red);
491+ pb.Next (acc, oa);
492+
493+ XLS_ASSERT_OK (pb.Build ().status ());
494+
495+ // 1. Evaluate BEFORE the pass runs on a clone of the package.
496+ XLS_ASSERT_OK_AND_ASSIGN (std::unique_ptr<Package> original_pkg,
497+ ClonePackage (p.get ()));
498+ XLS_ASSERT_OK_AND_ASSIGN (
499+ std::unique_ptr<SerialProcRuntime> eval_before,
500+ CreateInterpreterSerialProcRuntime (original_pkg.get ()));
501+ XLS_ASSERT_OK (eval_before->Tick ());
502+ XLS_ASSERT_OK (eval_before->Tick ());
503+ ChannelQueue& queue_before = eval_before->queue_manager ().GetQueue (
504+ original_pkg->GetChannel (" ch_out" ).value ());
505+ ASSERT_EQ (queue_before.GetSize (), 2 );
506+ std::optional<Value> val1_before = queue_before.Read ();
507+ std::optional<Value> val2_before = queue_before.Read ();
508+
509+ // 2. Run the pass (mutates package `p` in-place).
510+ ScopedRecordIr sri (p.get ());
511+ EXPECT_THAT (RunPass (p.get ()), IsOkAndHolds (false ));
512+
513+ // 3. Evaluate AFTER the pass runs.
514+ XLS_ASSERT_OK_AND_ASSIGN (std::unique_ptr<SerialProcRuntime> eval_after,
515+ CreateInterpreterSerialProcRuntime (p.get ()));
516+ XLS_ASSERT_OK (eval_after->Tick ());
517+ XLS_ASSERT_OK (eval_after->Tick ());
518+ ChannelQueue& queue_after = eval_after->queue_manager ().GetQueue (ch_out);
519+ ASSERT_EQ (queue_after.GetSize (), 2 );
520+ std::optional<Value> val1_after = queue_after.Read ();
521+ std::optional<Value> val2_after = queue_after.Read ();
522+
523+ // Confirm that the output values agree before and after the pass runs.
524+ EXPECT_EQ (val1_before, val1_after);
525+ EXPECT_EQ (val2_before, val2_after);
526+ }
527+
461528void IrFuzzArrayUntuple (FuzzPackageWithArgs fuzz_package_with_args) {
462529 ArrayUntuplePass pass;
463530 OptimizationPassChangesOutputs (std::move (fuzz_package_with_args), pass);
0 commit comments