Skip to content

Commit dfbcec2

Browse files
richmckeevercopybara-github
authored andcommitted
Support IR conversion of a simple parametric ProcDef.
PiperOrigin-RevId: 919066739
1 parent 367991f commit dfbcec2

15 files changed

Lines changed: 521 additions & 95 deletions

xls/dslx/ir_convert/conversion_record.cc

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -104,9 +104,10 @@ std::string ConversionRecord::ToString() const {
104104
}
105105
return absl::StrFormat(
106106
"ConversionRecord{m=%s, f=%s, top=%s, pid=%s, "
107-
"parametric_env=%s, type_info=%p, config=%s impl=%s}",
107+
"parametric_env=%s, type_info=%s, config=%s impl=%s}",
108108
module_->name(), f_->identifier(), is_top_ ? "true" : "false", proc_id,
109-
parametric_env_.ToString(), type_info_, config,
109+
parametric_env_.ToString(),
110+
type_info_ == nullptr ? "<none>" : type_info_->name(), config,
110111
f_->impl().has_value() ? (*f_->impl())->struct_ref()->ToString()
111112
: "none");
112113
}

xls/dslx/ir_convert/function_converter.cc

Lines changed: 29 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -3766,16 +3766,21 @@ absl::Status FunctionConverter::HandleChannelDecl(const ChannelDecl* node) {
37663766
return absl::OkStatus();
37673767
}
37683768

3769-
absl::Status FunctionConverter::InitProcDefBuilder(const ProcDef* proc_def) {
3770-
// TODO: https://github.com/google/xls/issues/4125 - Consider the parametrics
3771-
// and init values when invoking `MangleDslxName` here, using
3772-
// `HandleProcNextFunction` as a rough guide. ProcDef support is a WIP and we
3773-
// don't yet support parametrics.
3774-
ParametricEnv bindings;
3775-
XLS_ASSIGN_OR_RETURN(std::string mangled_name,
3776-
MangleDslxName(module_->name(), proc_def->identifier(),
3777-
CallingConvention::kProcNext,
3778-
/*free_keys=*/{}, &bindings));
3769+
absl::Status FunctionConverter::InitProcDefBuilder(const ProcDef* proc_def,
3770+
const ParametricEnv& env) {
3771+
absl::btree_set<std::string> parametric_keys;
3772+
3773+
// Include parametric values in the mangled names of non-top procs only.
3774+
if (!is_top_) {
3775+
for (const ParametricBinding* binding : proc_def->parametric_bindings()) {
3776+
parametric_keys.insert(binding->identifier());
3777+
}
3778+
}
3779+
3780+
XLS_ASSIGN_OR_RETURN(
3781+
std::string mangled_name,
3782+
MangleDslxName(module_->name(), proc_def->identifier(),
3783+
CallingConvention::kProcNext, parametric_keys, &env));
37793784
auto unique_builder =
37803785
std::make_unique<ProcBuilder>(NewStyleProc{}, mangled_name, package());
37813786
ProcBuilder* builder = unique_builder.get();
@@ -4139,26 +4144,29 @@ absl::Status FunctionConverter::AddProcDefInstantiation(
41394144

41404145
absl::Status FunctionConverter::ConvertProcDef(
41414146
const ProcDef* proc_def, const InterpValue& canonical_initializer_value,
4142-
ProcId proc_id, TypeInfo* type_info) {
4143-
VLOG(5) << "Converting ProcDef: " << proc_def->identifier();
4147+
ProcId proc_id, TypeInfo* constructor_ti, TypeInfo* next_ti,
4148+
const ParametricEnv& env) {
4149+
VLOG(5) << "Converting ProcDef: " << proc_def->identifier()
4150+
<< " using constructor TI " << constructor_ti->name()
4151+
<< " and next TI " << next_ti->name();
41444152

41454153
InterpValue::ProcInitializer canonical_initializer =
41464154
canonical_initializer_value.GetProcInitializerOrDie();
4147-
ScopedTypeInfoSwap stis(this, type_info);
4155+
ScopedTypeInfoSwap stis(this, constructor_ti);
41484156
proc_id_ = proc_id;
41494157
proc_data_->id_to_members[proc_id] = {};
41504158

4151-
XLS_RETURN_IF_ERROR(InitProcDefBuilder(proc_def));
4159+
XLS_RETURN_IF_ERROR(InitProcDefBuilder(proc_def, env));
41524160
XLS_RETURN_IF_ERROR(InitProcDefChannels(proc_def, canonical_initializer));
41534161
XLS_RETURN_IF_ERROR(
41544162
InitProcDefStateElements(proc_def, canonical_initializer));
41554163

41564164
XLS_ASSIGN_OR_RETURN(std::vector<InterpValue> spawnees,
4157-
type_info->GetProcDefSpawnsFrom(proc_def));
4165+
constructor_ti->GetProcDefSpawnsFrom(proc_def));
41584166
for (const InterpValue& spawnee : spawnees) {
4159-
const AstNode* definer = spawnee.GetProcInitializerOrDie().definer();
4160-
XLS_ASSIGN_OR_RETURN(InterpValue canonical_initializer,
4161-
current_type_info_->GetConstExpr(definer));
4167+
XLS_ASSIGN_OR_RETURN(
4168+
InterpValue canonical_initializer,
4169+
current_type_info_->GetCanonicalProcInitializer(spawnee));
41624170
XLS_RETURN_IF_ERROR(
41634171
AddProcDefInstantiation(proc_def, spawnee, canonical_initializer));
41644172
}
@@ -4168,7 +4176,9 @@ absl::Status FunctionConverter::ConvertProcDef(
41684176

41694177
XLS_ASSIGN_OR_RETURN(std::unique_ptr<ProcDefInstance> instance,
41704178
CreateProcDefInstance(proc_def));
4179+
41714180
SetNodeToIr((*next_fn)->params()[0]->name_def(), instance.get());
4181+
ScopedTypeInfoSwap stis_next(this, next_ti);
41724182
XLS_RETURN_IF_ERROR(Visit((*next_fn)->body()));
41734183

41744184
XLS_ASSIGN_OR_RETURN(
@@ -4179,7 +4189,7 @@ absl::Status FunctionConverter::ConvertProcDef(
41794189
// structure to support an InterpValue::ProcInitializer (or its whole
41804190
// contents + env) as a key.
41814191
package_data_.ir_to_dslx[p] = *next_fn;
4182-
package_data_.callee_to_ir_proc[{*next_fn, ParametricEnv{}}] = p;
4192+
package_data_.callee_to_ir_proc[{*next_fn, env}] = p;
41834193
package_data_
41844194
.canonical_proc_def_initializer_to_ir_proc[canonical_initializer_value] =
41854195
p;

xls/dslx/ir_convert/function_converter.h

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,7 +161,8 @@ class FunctionConverter {
161161

162162
absl::Status ConvertProcDef(const ProcDef* proc_def,
163163
const InterpValue& canonical_initializer_value,
164-
ProcId proc_id, TypeInfo* type_info);
164+
ProcId proc_id, TypeInfo* constructor_ti,
165+
TypeInfo* next_ti, const ParametricEnv& env);
165166

166167
// Notes a constant-definition dependency for the function (so it can
167168
// participate in the IR conversion).
@@ -544,7 +545,8 @@ class FunctionConverter {
544545
absl::Status HandleProcDef(const ProcDef* proc_def,
545546
const Function* constructor);
546547

547-
absl::Status InitProcDefBuilder(const ProcDef* proc_def);
548+
absl::Status InitProcDefBuilder(const ProcDef* proc_def,
549+
const ParametricEnv& env);
548550

549551
template <typename NodeType>
550552
absl::Status DefineProcDefChannelOrArray(

xls/dslx/ir_convert/get_conversion_records.cc

Lines changed: 32 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -63,14 +63,27 @@ absl::StatusOr<ConversionRecord> MakeConversionRecord(
6363

6464
// Variant for a `ProcDef` as opposed to a legacy `Proc`.
6565
absl::StatusOr<ConversionRecord> MakeConversionRecord(
66-
const ProcDef* p, ProcId proc_id, TypeInfo* ti, ParametricEnv env, bool top,
67-
std::optional<InterpValue> initializer = std::nullopt) {
66+
const ProcDef* p, ProcId proc_id, bool top,
67+
const ProcInitializerWithTypeInfo& canonical_initializer) {
6868
std::optional<Function*> next_fn = GetProcNextFunction(p);
6969
XLS_RET_CHECK(next_fn.has_value());
70-
return ConversionRecord::Make(*next_fn, p->owner(), ti, std::move(env),
71-
proc_id, top,
72-
/*config_record=*/nullptr,
73-
/*init_value=*/initializer, p);
70+
71+
// TODO: https://github.com/google/xls/issues/4125 - We are only using the
72+
// "config record" here as a place to store the TypeInfo for the constructor.
73+
// Legacy procs need more of the record. We should get rid of the separate
74+
// config record concept once we get rid of legacy procs.
75+
XLS_ASSIGN_OR_RETURN(
76+
ConversionRecord config_record,
77+
ConversionRecord::Make(*next_fn, p->owner(),
78+
canonical_initializer.constructor_type_info,
79+
canonical_initializer.constructor_env, proc_id,
80+
false, nullptr, std::nullopt, p));
81+
82+
return ConversionRecord::Make(
83+
*next_fn, p->owner(), canonical_initializer.next_type_info,
84+
canonical_initializer.constructor_env, proc_id, top,
85+
std::make_unique<ConversionRecord>(std::move(config_record)),
86+
/*init_value=*/canonical_initializer.initializer, p);
7487
}
7588

7689
// An AstNodeVisitor that creates ConversionRecords from appropriate AstNodes
@@ -244,7 +257,8 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {
244257
// We can still convert this function even though it's never been called.
245258
// Make sure we are using the right type info for imported functions.
246259
TypeInfo* invocation_ti = GetTypeInfo(f);
247-
VLOG(5) << "Processing fn " << f->ToString();
260+
VLOG(5) << "Processing fn " << f->ToString() << " with TI "
261+
<< invocation_ti->name();
248262
XLS_ASSIGN_OR_RETURN(ConversionRecord cr,
249263
MakeConversionRecord(const_cast<Function*>(f),
250264
f->owner(), invocation_ti,
@@ -373,24 +387,29 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {
373387
XLS_RET_CHECK(next_fn.has_value());
374388

375389
TypeInfo* proc_owner_ti = GetTypeInfo(p);
376-
XLS_ASSIGN_OR_RETURN(std::vector<InterpValue> canonical_initializers,
377-
proc_owner_ti->GetCanonicalProcInitializers(p));
390+
XLS_ASSIGN_OR_RETURN(
391+
std::vector<ProcInitializerWithTypeInfo> canonical_initializers,
392+
proc_owner_ti->GetCanonicalProcInitializers(p));
378393
if (p->IsParametric() && canonical_initializers.empty()) {
379394
VLOG(5) << "No calls to parametric proc " << p->name_def()->ToString();
380395
return absl::OkStatus();
381396
}
382-
for (const InterpValue& canonical_initializer : canonical_initializers) {
397+
for (const ProcInitializerWithTypeInfo& canonical_initializer :
398+
canonical_initializers) {
383399
// TODO: https://github.com/google/xls/issues/4125 - Exclude test-only
384400
// procs, and those only used in test-only contexts, if desired.
385401

386402
VLOG(5)
387403
<< "Making conversion record for canonical initializer of ProcDef: "
388-
<< canonical_initializer.ToString();
404+
<< canonical_initializer.initializer.ToString()
405+
<< " with constructor TI "
406+
<< canonical_initializer.constructor_type_info->name()
407+
<< " and next() TI " << canonical_initializer.next_type_info->name();
389408

390409
XLS_ASSIGN_OR_RETURN(
391410
ConversionRecord cr,
392-
MakeConversionRecord(p, proc_id_factory_->CreateProcId(p), type_info_,
393-
ParametricEnv(), /*top=*/top_ == *next_fn,
411+
MakeConversionRecord(p, proc_id_factory_->CreateProcId(p),
412+
/*top=*/top_ == *next_fn,
394413
canonical_initializer));
395414
records_.push_back(std::move(cr));
396415
}

xls/dslx/ir_convert/ir_converter.cc

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -190,6 +190,8 @@ absl::Status ConvertOneFunctionInternal(PackageData& package_data,
190190
ProcConversionData* proc_data,
191191
ChannelScope* channel_scope,
192192
const ConvertOptions& options) {
193+
VLOG(6) << "Convert record: " << record.ToString();
194+
193195
// Validate the requested conversion looks sound in terms of provided
194196
// parametrics.
195197
XLS_RETURN_IF_ERROR(ConversionRecord::ValidateParametrics(
@@ -206,8 +208,10 @@ absl::Status ConvertOneFunctionInternal(PackageData& package_data,
206208
}
207209

208210
if (record.proc_def().has_value()) {
209-
return converter.ConvertProcDef(*record.proc_def(), *record.init_value(),
210-
*record.proc_id(), record.type_info());
211+
return converter.ConvertProcDef(
212+
*record.proc_def(), *record.init_value(), *record.proc_id(),
213+
record.config_record()->type_info(), record.type_info(),
214+
record.parametric_env());
211215
}
212216

213217
Function* f = record.f();

xls/dslx/ir_convert/ir_converter_test.cc

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2444,6 +2444,70 @@ impl Main {
24442444
ExpectIr(conv.DumpIr());
24452445
}
24462446

2447+
TEST_F(IrConverterTest, ParametricProcDef) {
2448+
constexpr std::string_view kModule = R"(
2449+
#![feature(explicit_state_access)]
2450+
2451+
proc Loopback<N: u32> {
2452+
c_in: chan<uN[N]> in,
2453+
c_out: chan<uN[N]> out,
2454+
}
2455+
2456+
impl Loopback<N> {
2457+
fn new(c_in: chan<uN[N]> in, c_out: chan<uN[N]> out) -> Self {
2458+
Loopback { c_in, c_out }
2459+
}
2460+
2461+
fn next(self) {
2462+
let (t, val) = recv(join(), self.c_in);
2463+
send(t, self.c_out, val);
2464+
}
2465+
}
2466+
2467+
proc Main {
2468+
c_in: chan<u32> in,
2469+
c_out: chan<u32> out,
2470+
c_in_from_loopback: chan<u32> in,
2471+
c_out_to_loopback: chan<u32> out,
2472+
i: u32,
2473+
}
2474+
2475+
impl Main {
2476+
fn new(c_in: chan<u32> in, c_out: chan<u32> out) -> Self {
2477+
let (out_to_loopback, loopback_in) = chan<u32>("main_to_loopback");
2478+
let (loopback_out, in_from_loopback) = chan<u32>("loopback_to_main");
2479+
Loopback<32>::new(loopback_in, loopback_out).spawn();
2480+
2481+
Main {
2482+
c_in: c_in,
2483+
c_out: c_out,
2484+
c_in_from_loopback: in_from_loopback,
2485+
c_out_to_loopback: out_to_loopback,
2486+
i: 1
2487+
}
2488+
}
2489+
2490+
fn next(self) {
2491+
let i_val = read(self.i);
2492+
let (_, j) = recv(join(), self.c_in);
2493+
let loopback_tok = send(join(), self.c_out_to_loopback, j);
2494+
let (_, loopback_val) = recv(loopback_tok, self.c_in_from_loopback);
2495+
send(join(), self.c_out, i_val + loopback_val);
2496+
write(self.i, i_val + loopback_val);
2497+
}
2498+
}
2499+
)";
2500+
2501+
auto import_data = CreateImportDataForTest();
2502+
XLS_ASSERT_OK_AND_ASSIGN(
2503+
TypecheckedModule tm,
2504+
ParseAndTypecheck(kModule, "test_module.x", "test_module", &import_data));
2505+
XLS_ASSERT_OK_AND_ASSIGN(PackageConversionData conv,
2506+
ConvertModuleToPackage(tm.module, &import_data,
2507+
kProcScopedChannelOptions));
2508+
ExpectIr(conv.DumpIr());
2509+
}
2510+
24472511
TEST_F(IrConverterTest, TopProcDefWithIndrectConstructorResult) {
24482512
constexpr std::string_view program = R"(
24492513
#![feature(explicit_state_access)]
Lines changed: 68 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,68 @@
1+
package test_module
2+
3+
file_number 0 "test_module.x"
4+
5+
proc __test_module__Loopback__32_next<_c_in: bits[32] in, _c_out: bits[32] out>() {
6+
chan_interface _c_in(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
7+
chan_interface _c_out(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
8+
after_all.3: token = after_all(id=3)
9+
literal.2: bits[1] = literal(value=1, id=2)
10+
receive.4: (token, bits[32]) = receive(after_all.3, predicate=literal.2, channel=_c_in, id=4)
11+
t: token = tuple_index(receive.4, index=0, id=6)
12+
val: bits[32] = tuple_index(receive.4, index=1, id=7)
13+
__token: token = literal(value=token, id=1)
14+
tuple_index.5: token = tuple_index(receive.4, index=0, id=5)
15+
send.8: token = send(t, val, predicate=literal.2, channel=_c_out, id=8)
16+
tuple.9: () = tuple(id=9)
17+
}
18+
19+
proc __test_module__Main_next<_c_in: bits[32] in, _c_out: bits[32] out>(__i: bits[32], init={1}) {
20+
chan_interface _c_in(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
21+
chan_interface _c_out(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=ready_valid, flop_kind=none)
22+
chan _loopback_to_main(bits[32], id=0, kind=streaming, ops=send_receive, flow_control=ready_valid, strictness=proven_mutually_exclusive)
23+
chan_interface _loopback_to_main(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=none, flop_kind=none)
24+
chan_interface _loopback_to_main(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=none, flop_kind=none)
25+
chan _main_to_loopback(bits[32], id=1, kind=streaming, ops=send_receive, flow_control=ready_valid, strictness=proven_mutually_exclusive)
26+
chan_interface _main_to_loopback(direction=send, kind=streaming, strictness=proven_mutually_exclusive, flow_control=none, flop_kind=none)
27+
chan_interface _main_to_loopback(direction=receive, kind=streaming, strictness=proven_mutually_exclusive, flow_control=none, flop_kind=none)
28+
proc_instantiation __test_module__Loopback__32_next_inst(_main_to_loopback, _loopback_to_main, proc=__test_module__Loopback__32_next)
29+
after_all.23: token = after_all(id=23)
30+
literal.11: bits[1] = literal(value=1, id=11)
31+
receive.24: (token, bits[32]) = receive(after_all.23, predicate=literal.11, channel=_c_in, id=24)
32+
after_all.28: token = after_all(id=28)
33+
j: bits[32] = tuple_index(receive.24, index=1, id=27)
34+
literal.12: bits[1] = literal(value=0, id=12)
35+
loopback_tok: token = send(after_all.28, j, predicate=literal.11, channel=_main_to_loopback, id=29)
36+
not.15: bits[1] = not(literal.11, id=15)
37+
not.16: bits[1] = not(literal.12, id=16)
38+
__i__1: bits[32] = state_read(state_element=__i, predicate=literal.11, id=20)
39+
receive.30: (token, bits[32]) = receive(loopback_tok, predicate=literal.11, channel=_loopback_to_main, id=30)
40+
__token: token = literal(value=token, id=10)
41+
or.17: bits[1] = or(not.15, not.16, id=17)
42+
not.38: bits[1] = not(literal.11, id=38)
43+
or.19: bits[1] = or(literal.12, literal.11, id=19)
44+
literal.13: bits[1] = literal(value=0, id=13)
45+
i_val: bits[32] = identity(__i__1, id=21)
46+
loopback_val: bits[32] = tuple_index(receive.30, index=1, id=33)
47+
assert.18: token = assert(__token, or.17, message="State element read after read in same activation.", id=18)
48+
or.39: bits[1] = or(not.38, or.19, id=39)
49+
not.41: bits[1] = not(literal.11, id=41)
50+
not.42: bits[1] = not(literal.13, id=42)
51+
after_all.34: token = after_all(id=34)
52+
add.35: bits[32] = add(i_val, loopback_val, id=35)
53+
assert.40: token = assert(assert.18, or.39, message="State element written before read in same activation.", id=40)
54+
or.43: bits[1] = or(not.41, not.42, id=43)
55+
__i: bits[32] = state_read(state_element=__i, id=14)
56+
add.37: bits[32] = add(i_val, loopback_val, id=37)
57+
or.22: bits[1] = or(literal.12, literal.11, id=22)
58+
tuple_index.25: token = tuple_index(receive.24, index=0, id=25)
59+
tuple_index.26: token = tuple_index(receive.24, index=0, id=26)
60+
tuple_index.31: token = tuple_index(receive.30, index=0, id=31)
61+
tuple_index.32: token = tuple_index(receive.30, index=0, id=32)
62+
send.36: token = send(after_all.34, add.35, predicate=literal.11, channel=_c_out, id=36)
63+
assert.44: token = assert(assert.40, or.43, message="State element written after write in same activation.", id=44)
64+
or.45: bits[1] = or(literal.13, literal.11, id=45)
65+
next_value.46: () = next_value(param=__i, value=add.37, predicate=literal.11, id=46)
66+
tuple.47: () = tuple(id=47)
67+
tuple.48: () = tuple(id=48)
68+
}

0 commit comments

Comments
 (0)