Skip to content

Commit 50981d2

Browse files
erinzmoorecopybara-github
authored andcommitted
Start for supporting local type aliases in lambdas.
PiperOrigin-RevId: 919300513
1 parent b6d15bd commit 50981d2

2 files changed

Lines changed: 296 additions & 28 deletions

File tree

xls/dslx/frontend/semantics_analysis.cc

Lines changed: 108 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -285,6 +285,22 @@ class CollectNameRefs : public AstNodeVisitorWithDefault {
285285
return absl::OkStatus();
286286
}
287287

288+
absl::Status HandleTypeRefTypeAnnotation(
289+
const TypeRefTypeAnnotation* node) override {
290+
AnyNameDef name_def =
291+
TypeDefinitionGetNameDef(node->type_ref()->type_definition());
292+
if (std::holds_alternative<const NameDef*>(name_def)) {
293+
const NameDef* nd = std::get<const NameDef*>(name_def);
294+
type_refs_[nd].emplace(node);
295+
}
296+
return DefaultHandler(node);
297+
}
298+
299+
absl::Status HandleStructInstance(const StructInstance* node) override {
300+
XLS_RETURN_IF_ERROR(node->struct_ref()->Accept(this));
301+
return DefaultHandler(node);
302+
}
303+
288304
absl::Status DefaultHandler(const AstNode* node) override {
289305
bool prev_in_type_annotation = in_type_annotation_;
290306
if (node->kind() == AstNodeKind::kTypeAnnotation) {
@@ -298,7 +314,9 @@ class CollectNameRefs : public AstNodeVisitorWithDefault {
298314
}
299315
}
300316
for (const AstNode* child : node->GetChildren(/*want_types=*/true)) {
301-
XLS_RETURN_IF_ERROR(child->Accept(this));
317+
if (child->owner() == node->owner()) {
318+
XLS_RETURN_IF_ERROR(child->Accept(this));
319+
}
302320
}
303321
in_type_annotation_ = prev_in_type_annotation;
304322
return absl::OkStatus();
@@ -326,6 +344,23 @@ class CollectNameRefs : public AstNodeVisitorWithDefault {
326344
return result;
327345
}
328346

347+
absl::flat_hash_map<const NameDef*,
348+
absl::flat_hash_set<const TypeRefTypeAnnotation*>>
349+
TypesDefinedPrior(const Pos start) const {
350+
absl::flat_hash_map<const NameDef*,
351+
absl::flat_hash_set<const TypeRefTypeAnnotation*>>
352+
result;
353+
for (const auto& [name_def, type_refs] : type_refs_) {
354+
// Don't include top-level type definitions.
355+
std::optional<const ModuleMember*> module_member =
356+
name_def->owner()->FindMemberWithName(name_def->identifier());
357+
if (!module_member.has_value() && name_def->span().start() < start) {
358+
result.emplace(name_def, type_refs);
359+
}
360+
}
361+
return result;
362+
}
363+
329364
private:
330365
struct NameRefInfo {
331366
bool any_used_in_type_annotation;
@@ -349,6 +384,9 @@ class CollectNameRefs : public AstNodeVisitorWithDefault {
349384
}
350385

351386
absl::flat_hash_map<const NameDef*, NameRefInfo> name_ref_info_;
387+
absl::flat_hash_map<const NameDef*,
388+
absl::flat_hash_set<const TypeRefTypeAnnotation*>>
389+
type_refs_;
352390
bool in_type_annotation_ = false;
353391
};
354392

@@ -392,7 +430,7 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
392430
std::vector<ExprOrType> struct_instance_parametrics;
393431
// NameDefs that have been added to the struct parametric bindings.
394432
absl::flat_hash_set<const NameDef*> parametric_nds;
395-
absl::flat_hash_map<const NameRef*, NameRef*> name_ref_replacements;
433+
absl::flat_hash_map<const AstNode*, AstNode*> node_replacements;
396434

397435
// If there are any parametric bindings in the containing function that are
398436
// referenced in the lambda, they should be added as parametric bindings to
@@ -406,10 +444,20 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
406444
if (name_refs.empty()) {
407445
continue;
408446
}
409-
XLS_RETURN_IF_ERROR(HandleParentParametric(
447+
XLS_RETURN_IF_ERROR(AddBindingForParentParametric(
410448
module, parent_binding, name_refs, struct_parametric_bindings,
411449
struct_type_parametrics, struct_instance_parametrics,
412-
parametric_nds, name_ref_replacements));
450+
parametric_nds, node_replacements));
451+
}
452+
}
453+
454+
for (const auto& [original_nd, trtas] :
455+
collect_nr.TypesDefinedPrior(span.start())) {
456+
if (!parametric_nds.contains(original_nd)) {
457+
XLS_RETURN_IF_ERROR(ReplaceTypeRefTypeAnnotations(
458+
module, original_nd, trtas, struct_parametric_bindings,
459+
struct_type_parametrics, struct_instance_parametrics,
460+
parametric_nds, node_replacements));
413461
}
414462
}
415463

@@ -422,9 +470,9 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
422470
for (const NameDef* original_name_def :
423471
collect_nr.NameDefsDefinedPrior(span.start())) {
424472
if (!parametric_nds.contains(original_name_def)) {
425-
HandleCapture(module, original_name_def, struct_parametric_bindings,
426-
struct_type_parametrics, struct_members,
427-
struct_instance_members, seen);
473+
AddCapture(module, original_name_def, struct_parametric_bindings,
474+
struct_type_parametrics, struct_members,
475+
struct_instance_members, seen);
428476
}
429477
}
430478

@@ -457,7 +505,7 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
457505
NameDef* self_nd = module->Make<NameDef>(
458506
span, KeywordToString(Keyword::kSelf), /*definer=*/nullptr);
459507
CloneReplacer insert_self =
460-
[self_nd, seen, name_ref_replacements](
508+
[self_nd, seen, node_replacements](
461509
const AstNode* node, const Module*,
462510
const absl::flat_hash_map<const AstNode*, AstNode*>&)
463511
-> std::optional<AstNode*> {
@@ -473,22 +521,20 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
473521
return node->owner()->Make<Attr>(name_def->span(), self_nr,
474522
name_def->identifier(),
475523
/* in_parens= */ false);
476-
} else if (name_ref_replacements.contains(name_ref)) {
477-
return name_ref_replacements.at(name_ref);
478524
}
479525
}
526+
if (node_replacements.contains(node)) {
527+
return node_replacements.at(node);
528+
}
480529
return std::nullopt;
481530
};
482-
CloneReplacer swap_name_refs =
483-
[name_ref_replacements](
531+
CloneReplacer swap_nodes =
532+
[node_replacements](
484533
const AstNode* node, const Module*,
485534
const absl::flat_hash_map<const AstNode*, AstNode*>&)
486535
-> std::optional<AstNode*> {
487-
if (node->kind() == AstNodeKind::kNameRef) {
488-
const NameRef* name_ref = absl::down_cast<const NameRef*>(node);
489-
if (name_ref_replacements.contains(name_ref)) {
490-
return name_ref_replacements.at(name_ref);
491-
}
536+
if (node_replacements.contains(node)) {
537+
return node_replacements.at(node);
492538
}
493539
return std::nullopt;
494540
};
@@ -501,7 +547,7 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
501547
AstNode * cloned_return_type,
502548
CloneAst(original_fn->return_type(),
503549
ChainCloneReplacers(&PreserveTypeDefinitionsReplacer,
504-
std::move(swap_name_refs))));
550+
std::move(swap_nodes))));
505551
SelfTypeAnnotation* self_type = module->Make<SelfTypeAnnotation>(
506552
span, /*explicit_type=*/false, struct_type_annotation);
507553
std::vector<Param*> params = {module->Make<Param>(self_nd, self_type)};
@@ -551,14 +597,14 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
551597
}
552598

553599
private:
554-
absl::Status HandleParentParametric(
600+
absl::Status AddBindingForParentParametric(
555601
Module* module, const ParametricBinding* parent_binding,
556602
absl::flat_hash_set<const NameRef*> name_refs,
557603
std::vector<ParametricBinding*>& struct_parametric_bindings,
558604
std::vector<ExprOrType>& struct_type_parametrics,
559605
std::vector<ExprOrType>& struct_instance_parametrics,
560606
absl::flat_hash_set<const NameDef*>& parametric_nds,
561-
absl::flat_hash_map<const NameRef*, NameRef*>& name_ref_replacements) {
607+
absl::flat_hash_map<const AstNode*, AstNode*>& node_replacements) {
562608
NameDef* lambda_struct_nd = module->Make<NameDef>(
563609
parent_binding->span(), parent_binding->identifier() + "_ls",
564610
/*definer=*/nullptr);
@@ -590,7 +636,7 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
590636
}
591637
parametric_nds.insert(parent_binding->name_def());
592638
for (const NameRef* original_name_ref : name_refs) {
593-
name_ref_replacements.emplace(
639+
node_replacements.emplace(
594640
original_name_ref,
595641
module->Make<NameRef>(original_name_ref->span(),
596642
lambda_struct_nd->identifier(),
@@ -599,7 +645,47 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
599645
return absl::OkStatus();
600646
}
601647

602-
void HandleCapture(
648+
absl::Status ReplaceTypeRefTypeAnnotations(
649+
Module* module, const NameDef* original_nd,
650+
absl::flat_hash_set<const TypeRefTypeAnnotation*> trtas,
651+
std::vector<ParametricBinding*>& struct_parametric_bindings,
652+
std::vector<ExprOrType>& struct_type_parametrics,
653+
std::vector<ExprOrType>& struct_instance_parametrics,
654+
absl::flat_hash_set<const NameDef*>& parametric_nds,
655+
absl::flat_hash_map<const AstNode*, AstNode*>& node_replacements) {
656+
NameDef* lambda_struct_nd = module->Make<NameDef>(
657+
original_nd->span(),
658+
absl::Substitute("$0_ls", original_nd->identifier()),
659+
/*definer=*/nullptr);
660+
ParametricBinding* lambda_struct_binding = module->Make<ParametricBinding>(
661+
lambda_struct_nd, module->Make<GenericTypeAnnotation>(Span::None()),
662+
/*expr=*/nullptr);
663+
struct_parametric_bindings.push_back(lambda_struct_binding);
664+
NameRef* struct_type_parametric_nr = module->Make<NameRef>(
665+
original_nd->span(), lambda_struct_nd->identifier(), lambda_struct_nd);
666+
struct_type_parametrics.push_back(struct_type_parametric_nr);
667+
668+
XLS_ASSIGN_OR_RETURN(TypeDefinition type_def,
669+
ToTypeDefinition(original_nd->definer()));
670+
TypeRef* lambda_type_ref =
671+
module->Make<TypeRef>(original_nd->span(), type_def);
672+
673+
struct_instance_parametrics.push_back(module->Make<TypeRefTypeAnnotation>(
674+
original_nd->span(), lambda_type_ref, std::vector<ExprOrType>{}));
675+
parametric_nds.insert(original_nd);
676+
for (const TypeRefTypeAnnotation* original_type_ref : trtas) {
677+
node_replacements.emplace(
678+
original_type_ref,
679+
module->Make<TypeVariableTypeAnnotation>(
680+
module->Make<NameRef>(original_type_ref->span(),
681+
lambda_struct_nd->identifier(),
682+
lambda_struct_nd),
683+
/*internal=*/true));
684+
}
685+
return absl::OkStatus();
686+
}
687+
688+
void AddCapture(
603689
Module* module, const NameDef* original_name_def,
604690
std::vector<ParametricBinding*>& struct_parametric_bindings,
605691
std::vector<ExprOrType>& struct_type_parametrics,

0 commit comments

Comments
 (0)