Skip to content

Commit 801871b

Browse files
richmckeevercopybara-github
authored andcommitted
Fix bug where PSC IR conversion drops trait-derived functions too eagerly.
The problem was that when HandleInvocation would encounter an invocation like `c.to_bits()`, it would create a new visitor using the callee TypeInfo for that invocation, and call `new_visitor->HandleFunctionInternal(callee)`. The latter would then check if it was a derived function with no calls. But the call corresponding to the HandleInvocation in this scenario would not be in the callee TypeInfo. So this changes the logic to convert the function if we encounter it via a call. We also add better logging. PiperOrigin-RevId: 896099586
1 parent 4d862c3 commit 801871b

4 files changed

Lines changed: 49 additions & 9 deletions

File tree

xls/dslx/ir_convert/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,7 @@ cc_library(
531531
"@com_google_absl//absl/log",
532532
"@com_google_absl//absl/status",
533533
"@com_google_absl//absl/status:statusor",
534+
"@com_google_absl//absl/strings",
534535
"@com_google_absl//absl/strings:str_format",
535536
],
536537
)

xls/dslx/ir_convert/get_conversion_records.cc

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,11 @@
1414

1515
#include "xls/dslx/ir_convert/get_conversion_records.h"
1616

17+
#include <cstdint>
1718
#include <ios>
1819
#include <memory>
1920
#include <optional>
21+
#include <string>
2022
#include <string_view>
2123
#include <utility>
2224
#include <variant>
@@ -26,6 +28,7 @@
2628
#include "absl/log/log.h"
2729
#include "absl/status/status.h"
2830
#include "absl/status/statusor.h"
31+
#include "absl/strings/str_cat.h"
2932
#include "absl/strings/str_format.h"
3033
#include "xls/common/status/ret_check.h"
3134
#include "xls/common/status/status_macros.h"
@@ -160,7 +163,8 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {
160163
return absl::OkStatus();
161164
}
162165

163-
return HandleFunctionInternal(f, ParametricEnv{});
166+
return HandleFunctionInternal(f, ParametricEnv{},
167+
/*handle_for_invocation=*/false);
164168
}
165169

166170
// Handles all non-proc functions. For parametric functions, we only get here
@@ -169,7 +173,8 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {
169173
// non-parametric, standalone functions, the `env` should be empty and the
170174
// `type_info_` on the visitor can be the root one.
171175
absl::Status HandleFunctionInternal(const Function* f,
172-
const ParametricEnv& env) {
176+
const ParametricEnv& env,
177+
bool handle_for_invocation) {
173178
XLS_RET_CHECK(module_ == f->owner());
174179
XLS_RET_CHECK(!f->IsInProc());
175180
if (f->IsParametric() || f->IsMethodOnParametricStruct()) {
@@ -181,9 +186,14 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {
181186

182187
std::vector<InvocationCalleeData> calls =
183188
type_info_->GetUniqueInvocationCalleeData(f);
184-
if (f->IsCompilerDerived() && calls.empty()) {
185-
VLOG(5) << "No calls to derived function " << f->name_def()->ToString()
186-
<< "; not traversing for dependencies.";
189+
if (f->IsCompilerDerived() && calls.empty() && !handle_for_invocation) {
190+
std::string f_name = f->identifier();
191+
if (f->impl().has_value()) {
192+
f_name =
193+
absl::StrCat((*f->impl())->struct_ref()->ToString(), "::", f_name);
194+
}
195+
VLOG(5) << "No calls to derived function " << f_name << " in TI "
196+
<< type_info_->name() << "; not traversing for dependencies.";
187197
return absl::OkStatus();
188198
}
189199

@@ -247,8 +257,8 @@ class ConversionRecordVisitor : public AstNodeVisitorWithDefault {
247257
proc_id_factory_, top_, resolved_proc_alias_, records_,
248258
processed_invocations_);
249259

250-
XLS_RETURN_IF_ERROR(
251-
visitor.HandleFunctionInternal(call.callee, call.callee_bindings));
260+
XLS_RETURN_IF_ERROR(visitor.HandleFunctionInternal(
261+
call.callee, call.callee_bindings, /*handle_for_invocation=*/true));
252262

253263
VLOG(5) << "Processing invocation " << invocation->ToString();
254264
XLS_ASSIGN_OR_RETURN(

xls/dslx/ir_convert/ir_converter_test.cc

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8626,5 +8626,24 @@ proc main {
86268626
ExpectIr(converted);
86278627
}
86288628

8629+
TEST_F(IrConverterTest, DerivedToBitsNestedStructs) {
8630+
constexpr std::string_view kDslxProgram = R"(
8631+
#[derive(ToBits)]
8632+
pub struct A { x: u32 }
8633+
8634+
#[derive(ToBits)]
8635+
pub struct B { a: A }
8636+
8637+
#[derive(ToBits)]
8638+
pub struct C { b: B }
8639+
8640+
pub fn main(c: C) -> bits[32] {
8641+
c.to_bits()
8642+
}
8643+
)";
8644+
8645+
XLS_ASSERT_OK(ConvertModuleForTest(kDslxProgram));
8646+
}
8647+
86298648
} // namespace
86308649
} // namespace xls::dslx

xls/dslx/type_system_v2/inference_table_converter_impl.cc

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -683,8 +683,7 @@ class InferenceTableConverterImpl : public InferenceTableConverter,
683683
XLS_ASSIGN_OR_RETURN(
684684
TypeInfo * invocation_type_info,
685685
import_data_.type_info_owner().New(
686-
function->owner(),
687-
absl::StrCat("invocation_of_", function->identifier()),
686+
function->owner(), CreateInvocationTypeInfoName(function),
688687
base_type_info));
689688

690689
std::optional<const StructDefBase*> target_struct = std::nullopt;
@@ -913,6 +912,17 @@ class InferenceTableConverterImpl : public InferenceTableConverter,
913912
return absl::OkStatus();
914913
}
915914

915+
// Generates a name for a `TypeInfo` object to be used for an invocation of
916+
// `f`.
917+
std::string CreateInvocationTypeInfoName(const Function* f) {
918+
if (f->impl().has_value()) {
919+
return absl::Substitute("invocation_of_$0::$1",
920+
(*f->impl())->struct_ref()->ToString(),
921+
f->identifier());
922+
}
923+
return absl::StrCat("invocation_of_", f->identifier());
924+
}
925+
916926
// Gets the output `TypeInfo` corresponding to the given
917927
// `parametric_context`, which may be `nullopt`, in which case it returns
918928
// the base type info. If the proposed type info doesn't belong to the

0 commit comments

Comments
 (0)