@@ -285,6 +285,22 @@ class CollectNameRefs : public AstNodeVisitorWithDefault {
285285 return absl::OkStatus ();
286286 }
287287
288+ absl::Status HandleTypeRefTypeAnnotation (
289+ const TypeRefTypeAnnotation* node) override {
290+ AnyNameDef name_def =
291+ TypeDefinitionGetNameDef (node->type_ref ()->type_definition ());
292+ if (std::holds_alternative<const NameDef*>(name_def)) {
293+ const NameDef* nd = std::get<const NameDef*>(name_def);
294+ type_refs_[nd].emplace (node);
295+ }
296+ return DefaultHandler (node);
297+ }
298+
299+ absl::Status HandleStructInstance (const StructInstance* node) override {
300+ XLS_RETURN_IF_ERROR (node->struct_ref ()->Accept (this ));
301+ return DefaultHandler (node);
302+ }
303+
288304 absl::Status DefaultHandler (const AstNode* node) override {
289305 bool prev_in_type_annotation = in_type_annotation_;
290306 if (node->kind () == AstNodeKind::kTypeAnnotation ) {
@@ -298,7 +314,9 @@ class CollectNameRefs : public AstNodeVisitorWithDefault {
298314 }
299315 }
300316 for (const AstNode* child : node->GetChildren (/* want_types=*/ true )) {
301- XLS_RETURN_IF_ERROR (child->Accept (this ));
317+ if (child->owner () == node->owner ()) {
318+ XLS_RETURN_IF_ERROR (child->Accept (this ));
319+ }
302320 }
303321 in_type_annotation_ = prev_in_type_annotation;
304322 return absl::OkStatus ();
@@ -326,6 +344,23 @@ class CollectNameRefs : public AstNodeVisitorWithDefault {
326344 return result;
327345 }
328346
347+ absl::flat_hash_map<const NameDef*,
348+ absl::flat_hash_set<const TypeRefTypeAnnotation*>>
349+ TypesDefinedPrior (const Pos start) const {
350+ absl::flat_hash_map<const NameDef*,
351+ absl::flat_hash_set<const TypeRefTypeAnnotation*>>
352+ result;
353+ for (const auto & [name_def, type_refs] : type_refs_) {
354+ // Don't include top-level type definitions.
355+ std::optional<const ModuleMember*> module_member =
356+ name_def->owner ()->FindMemberWithName (name_def->identifier ());
357+ if (!module_member.has_value () && name_def->span ().start () < start) {
358+ result.emplace (name_def, type_refs);
359+ }
360+ }
361+ return result;
362+ }
363+
329364 private:
330365 struct NameRefInfo {
331366 bool any_used_in_type_annotation;
@@ -349,6 +384,9 @@ class CollectNameRefs : public AstNodeVisitorWithDefault {
349384 }
350385
351386 absl::flat_hash_map<const NameDef*, NameRefInfo> name_ref_info_;
387+ absl::flat_hash_map<const NameDef*,
388+ absl::flat_hash_set<const TypeRefTypeAnnotation*>>
389+ type_refs_;
352390 bool in_type_annotation_ = false ;
353391};
354392
@@ -392,7 +430,7 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
392430 std::vector<ExprOrType> struct_instance_parametrics;
393431 // NameDefs that have been added to the struct parametric bindings.
394432 absl::flat_hash_set<const NameDef*> parametric_nds;
395- absl::flat_hash_map<const NameRef *, NameRef *> name_ref_replacements ;
433+ absl::flat_hash_map<const AstNode *, AstNode *> node_replacements ;
396434
397435 // If there are any parametric bindings in the containing function that are
398436 // referenced in the lambda, they should be added as parametric bindings to
@@ -406,10 +444,20 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
406444 if (name_refs.empty ()) {
407445 continue ;
408446 }
409- XLS_RETURN_IF_ERROR (HandleParentParametric (
447+ XLS_RETURN_IF_ERROR (AddBindingForParentParametric (
410448 module , parent_binding, name_refs, struct_parametric_bindings,
411449 struct_type_parametrics, struct_instance_parametrics,
412- parametric_nds, name_ref_replacements));
450+ parametric_nds, node_replacements));
451+ }
452+ }
453+
454+ for (const auto & [original_nd, trtas] :
455+ collect_nr.TypesDefinedPrior (span.start ())) {
456+ if (!parametric_nds.contains (original_nd)) {
457+ XLS_RETURN_IF_ERROR (ReplaceTypeRefTypeAnnotations (
458+ module , original_nd, trtas, struct_parametric_bindings,
459+ struct_type_parametrics, struct_instance_parametrics,
460+ parametric_nds, node_replacements));
413461 }
414462 }
415463
@@ -422,9 +470,9 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
422470 for (const NameDef* original_name_def :
423471 collect_nr.NameDefsDefinedPrior (span.start ())) {
424472 if (!parametric_nds.contains (original_name_def)) {
425- HandleCapture (module , original_name_def, struct_parametric_bindings,
426- struct_type_parametrics, struct_members,
427- struct_instance_members, seen);
473+ AddCapture (module , original_name_def, struct_parametric_bindings,
474+ struct_type_parametrics, struct_members,
475+ struct_instance_members, seen);
428476 }
429477 }
430478
@@ -457,7 +505,7 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
457505 NameDef* self_nd = module ->Make <NameDef>(
458506 span, KeywordToString (Keyword::kSelf ), /* definer=*/ nullptr );
459507 CloneReplacer insert_self =
460- [self_nd, seen, name_ref_replacements ](
508+ [self_nd, seen, node_replacements ](
461509 const AstNode* node, const Module*,
462510 const absl::flat_hash_map<const AstNode*, AstNode*>&)
463511 -> std::optional<AstNode*> {
@@ -473,22 +521,20 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
473521 return node->owner ()->Make <Attr>(name_def->span (), self_nr,
474522 name_def->identifier (),
475523 /* in_parens= */ false );
476- } else if (name_ref_replacements.contains (name_ref)) {
477- return name_ref_replacements.at (name_ref);
478524 }
479525 }
526+ if (node_replacements.contains (node)) {
527+ return node_replacements.at (node);
528+ }
480529 return std::nullopt ;
481530 };
482- CloneReplacer swap_name_refs =
483- [name_ref_replacements ](
531+ CloneReplacer swap_nodes =
532+ [node_replacements ](
484533 const AstNode* node, const Module*,
485534 const absl::flat_hash_map<const AstNode*, AstNode*>&)
486535 -> std::optional<AstNode*> {
487- if (node->kind () == AstNodeKind::kNameRef ) {
488- const NameRef* name_ref = absl::down_cast<const NameRef*>(node);
489- if (name_ref_replacements.contains (name_ref)) {
490- return name_ref_replacements.at (name_ref);
491- }
536+ if (node_replacements.contains (node)) {
537+ return node_replacements.at (node);
492538 }
493539 return std::nullopt ;
494540 };
@@ -501,7 +547,7 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
501547 AstNode * cloned_return_type,
502548 CloneAst (original_fn->return_type (),
503549 ChainCloneReplacers (&PreserveTypeDefinitionsReplacer,
504- std::move (swap_name_refs ))));
550+ std::move (swap_nodes ))));
505551 SelfTypeAnnotation* self_type = module ->Make <SelfTypeAnnotation>(
506552 span, /* explicit_type=*/ false , struct_type_annotation);
507553 std::vector<Param*> params = {module ->Make <Param>(self_nd, self_type)};
@@ -551,14 +597,14 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
551597 }
552598
553599 private:
554- absl::Status HandleParentParametric (
600+ absl::Status AddBindingForParentParametric (
555601 Module* module , const ParametricBinding* parent_binding,
556602 absl::flat_hash_set<const NameRef*> name_refs,
557603 std::vector<ParametricBinding*>& struct_parametric_bindings,
558604 std::vector<ExprOrType>& struct_type_parametrics,
559605 std::vector<ExprOrType>& struct_instance_parametrics,
560606 absl::flat_hash_set<const NameDef*>& parametric_nds,
561- absl::flat_hash_map<const NameRef *, NameRef *>& name_ref_replacements ) {
607+ absl::flat_hash_map<const AstNode *, AstNode *>& node_replacements ) {
562608 NameDef* lambda_struct_nd = module ->Make <NameDef>(
563609 parent_binding->span (), parent_binding->identifier () + " _ls" ,
564610 /* definer=*/ nullptr );
@@ -590,7 +636,7 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
590636 }
591637 parametric_nds.insert (parent_binding->name_def ());
592638 for (const NameRef* original_name_ref : name_refs) {
593- name_ref_replacements .emplace (
639+ node_replacements .emplace (
594640 original_name_ref,
595641 module ->Make <NameRef>(original_name_ref->span (),
596642 lambda_struct_nd->identifier (),
@@ -599,7 +645,47 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
599645 return absl::OkStatus ();
600646 }
601647
602- void HandleCapture (
648+ absl::Status ReplaceTypeRefTypeAnnotations (
649+ Module* module , const NameDef* original_nd,
650+ absl::flat_hash_set<const TypeRefTypeAnnotation*> trtas,
651+ std::vector<ParametricBinding*>& struct_parametric_bindings,
652+ std::vector<ExprOrType>& struct_type_parametrics,
653+ std::vector<ExprOrType>& struct_instance_parametrics,
654+ absl::flat_hash_set<const NameDef*>& parametric_nds,
655+ absl::flat_hash_map<const AstNode*, AstNode*>& node_replacements) {
656+ NameDef* lambda_struct_nd = module ->Make <NameDef>(
657+ original_nd->span (),
658+ absl::Substitute (" $0_ls" , original_nd->identifier ()),
659+ /* definer=*/ nullptr );
660+ ParametricBinding* lambda_struct_binding = module ->Make <ParametricBinding>(
661+ lambda_struct_nd, module ->Make <GenericTypeAnnotation>(Span::None ()),
662+ /* expr=*/ nullptr );
663+ struct_parametric_bindings.push_back (lambda_struct_binding);
664+ NameRef* struct_type_parametric_nr = module ->Make <NameRef>(
665+ original_nd->span (), lambda_struct_nd->identifier (), lambda_struct_nd);
666+ struct_type_parametrics.push_back (struct_type_parametric_nr);
667+
668+ XLS_ASSIGN_OR_RETURN (TypeDefinition type_def,
669+ ToTypeDefinition (original_nd->definer ()));
670+ TypeRef* lambda_type_ref =
671+ module ->Make <TypeRef>(original_nd->span (), type_def);
672+
673+ struct_instance_parametrics.push_back (module ->Make <TypeRefTypeAnnotation>(
674+ original_nd->span (), lambda_type_ref, std::vector<ExprOrType>{}));
675+ parametric_nds.insert (original_nd);
676+ for (const TypeRefTypeAnnotation* original_type_ref : trtas) {
677+ node_replacements.emplace (
678+ original_type_ref,
679+ module ->Make <TypeVariableTypeAnnotation>(
680+ module ->Make <NameRef>(original_type_ref->span (),
681+ lambda_struct_nd->identifier (),
682+ lambda_struct_nd),
683+ /* internal=*/ true ));
684+ }
685+ return absl::OkStatus ();
686+ }
687+
688+ void AddCapture (
603689 Module* module , const NameDef* original_name_def,
604690 std::vector<ParametricBinding*>& struct_parametric_bindings,
605691 std::vector<ExprOrType>& struct_type_parametrics,
0 commit comments