@@ -383,18 +383,21 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
383383 Span span = node->span ();
384384 CollectNameRefs collect_nr;
385385 XLS_RETURN_IF_ERROR (node->Accept (&collect_nr));
386- absl::flat_hash_set<const NameDef*> seen;
387386
388- // If there are any parametric bindings in the containing function that are
389- // referenced in the lambda, they should be added as parametric bindings to
390- // the `StructDef`.
391- std::vector<ParametricBinding*> struct_parametrics;
392- std::optional<const Function*> containing_fn = GetContainingFunction (node);
387+ // Parametric bindings for the struct definition.
388+ std::vector<ParametricBinding*> struct_parametric_bindings;
389+ // Parametrics in the struct type annotation.
393390 std::vector<ExprOrType> struct_type_parametrics;
391+ // Parametric values for the struct instantiation.
394392 std::vector<ExprOrType> struct_instance_parametrics;
393+ // NameDefs that have been added to the struct parametric bindings.
395394 absl::flat_hash_set<const NameDef*> parametric_nds;
396395 absl::flat_hash_map<const NameRef*, NameRef*> name_ref_replacements;
397396
397+ // If there are any parametric bindings in the containing function that are
398+ // referenced in the lambda, they should be added as parametric bindings to
399+ // the `StructDef`.
400+ std::optional<const Function*> containing_fn = GetContainingFunction (node);
398401 if (containing_fn.has_value ()) {
399402 for (ParametricBinding* parent_binding :
400403 (*containing_fn)->parametric_bindings ()) {
@@ -403,44 +406,10 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
403406 if (name_refs.empty ()) {
404407 continue ;
405408 }
406- NameDef* lambda_struct_nd = module ->Make <NameDef>(
407- parent_binding->span (), parent_binding->identifier () + " _ls" ,
408- /* definer=*/ nullptr );
409- XLS_ASSIGN_OR_RETURN (AstNode * cloned_ta,
410- CloneAst (parent_binding->type_annotation ()));
411-
412- AstNode* cloned_expr = nullptr ;
413- if (parent_binding->expr () != nullptr ) {
414- XLS_ASSIGN_OR_RETURN (cloned_expr, CloneAst (parent_binding->expr ()));
415- }
416- ParametricBinding* lambda_struct_binding =
417- module ->Make <ParametricBinding>(
418- lambda_struct_nd, absl::down_cast<TypeAnnotation*>(cloned_ta),
419- absl::down_cast<Expr*>(cloned_expr));
420- struct_parametrics.push_back (lambda_struct_binding);
421- NameRef* struct_type_parametric_nr = module ->Make <NameRef>(
422- parent_binding->span (), lambda_struct_nd->identifier (),
423- lambda_struct_nd);
424- NameRef* struct_instance_parametric_nr = module ->Make <NameRef>(
425- parent_binding->span (), parent_binding->identifier (),
426- parent_binding->name_def ());
427- struct_type_parametrics.push_back (struct_type_parametric_nr);
428- if (parent_binding->type_annotation ()
429- ->IsAnnotation <GenericTypeAnnotation>()) {
430- struct_instance_parametrics.push_back (
431- module ->Make <TypeVariableTypeAnnotation>(
432- struct_instance_parametric_nr));
433- } else {
434- struct_instance_parametrics.push_back (struct_instance_parametric_nr);
435- }
436- parametric_nds.insert (parent_binding->name_def ());
437- for (const NameRef* original_name_ref : name_refs) {
438- name_ref_replacements.emplace (
439- original_name_ref,
440- module ->Make <NameRef>(original_name_ref->span (),
441- lambda_struct_nd->identifier (),
442- lambda_struct_nd));
443- }
409+ XLS_RETURN_IF_ERROR (HandleParentParametric (
410+ module , parent_binding, name_refs, struct_parametric_bindings,
411+ struct_type_parametrics, struct_instance_parametrics,
412+ parametric_nds, name_ref_replacements));
444413 }
445414 }
446415
@@ -449,42 +418,13 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
449418 // already added as a parametric binding.
450419 std::vector<StructMemberNode*> struct_members;
451420 std::vector<std::pair<std::string, Expr*>> struct_instance_members;
421+ absl::flat_hash_set<const NameDef*> seen;
452422 for (const NameDef* original_name_def :
453423 collect_nr.NameDefsDefinedPrior (span.start ())) {
454424 if (!parametric_nds.contains (original_name_def)) {
455- // Create parametric binding with generic type to use for the context
456- // variable type.
457- GenericTypeAnnotation* gta =
458- module ->Make <GenericTypeAnnotation>(original_name_def->span ());
459- NameDef* generic_name_def = module ->Make <NameDef>(
460- original_name_def->span (),
461- absl::Substitute (" parametric_type_for_$0" ,
462- original_name_def->identifier ()),
463- /* definer=*/ gta);
464- NameRef* generic_name_ref = module ->Make <NameRef>(
465- original_name_def->span (), generic_name_def->identifier (),
466- generic_name_def);
467- struct_parametrics.push_back (module ->Make <ParametricBinding>(
468- generic_name_def, gta, /* expr=*/ nullptr ));
469-
470- NameDef* struct_member_nd = module ->Make <NameDef>(
471- original_name_def->span (), original_name_def->identifier (),
472- /* definer=*/ nullptr );
473- TypeVariableTypeAnnotation* tvta =
474- module ->Make <TypeVariableTypeAnnotation>(generic_name_ref,
475- /* internal=*/ true );
476- StructMemberNode* struct_member = module ->Make <StructMemberNode>(
477- Span::None (), struct_member_nd, Span::None (), tvta);
478- struct_members.push_back (struct_member);
479-
480- // Make a name ref that points to the original name def. Add as a member
481- // to a new struct instance.
482- NameRef* struct_instance_nr = module ->Make <NameRef>(
483- original_name_def->span (), original_name_def->identifier (),
484- original_name_def);
485- struct_instance_members.push_back (std::make_pair (
486- original_name_def->identifier (), struct_instance_nr));
487- seen.insert (original_name_def);
425+ HandleCapture (module , original_name_def, struct_parametric_bindings,
426+ struct_type_parametrics, struct_members,
427+ struct_instance_members, seen);
488428 }
489429 }
490430
@@ -494,7 +434,7 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
494434 span.ToString (file_table_)),
495435 /* definer=*/ nullptr );
496436 StructDef* full_struct_def =
497- module ->Make <StructDef>(span, struct_nd, struct_parametrics ,
437+ module ->Make <StructDef>(span, struct_nd, struct_parametric_bindings ,
498438 struct_members, /* is_public=*/ false );
499439 TypeRefTypeAnnotation* struct_type_annotation =
500440 module ->Make <TypeRefTypeAnnotation>(
@@ -611,6 +551,97 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
611551 }
612552
613553 private:
554+ absl::Status HandleParentParametric (
555+ Module* module , const ParametricBinding* parent_binding,
556+ absl::flat_hash_set<const NameRef*> name_refs,
557+ std::vector<ParametricBinding*>& struct_parametric_bindings,
558+ std::vector<ExprOrType>& struct_type_parametrics,
559+ std::vector<ExprOrType>& struct_instance_parametrics,
560+ absl::flat_hash_set<const NameDef*>& parametric_nds,
561+ absl::flat_hash_map<const NameRef*, NameRef*>& name_ref_replacements) {
562+ NameDef* lambda_struct_nd = module ->Make <NameDef>(
563+ parent_binding->span (), parent_binding->identifier () + " _ls" ,
564+ /* definer=*/ nullptr );
565+ XLS_ASSIGN_OR_RETURN (AstNode * cloned_ta,
566+ CloneAst (parent_binding->type_annotation ()));
567+
568+ AstNode* cloned_expr = nullptr ;
569+ if (parent_binding->expr () != nullptr ) {
570+ XLS_ASSIGN_OR_RETURN (cloned_expr, CloneAst (parent_binding->expr ()));
571+ }
572+ ParametricBinding* lambda_struct_binding = module ->Make <ParametricBinding>(
573+ lambda_struct_nd, absl::down_cast<TypeAnnotation*>(cloned_ta),
574+ absl::down_cast<Expr*>(cloned_expr));
575+ struct_parametric_bindings.push_back (lambda_struct_binding);
576+ NameRef* struct_type_parametric_nr =
577+ module ->Make <NameRef>(parent_binding->span (),
578+ lambda_struct_nd->identifier (), lambda_struct_nd);
579+ NameRef* struct_instance_parametric_nr = module ->Make <NameRef>(
580+ parent_binding->span (), parent_binding->identifier (),
581+ parent_binding->name_def ());
582+ struct_type_parametrics.push_back (struct_type_parametric_nr);
583+ if (parent_binding->type_annotation ()
584+ ->IsAnnotation <GenericTypeAnnotation>()) {
585+ struct_instance_parametrics.push_back (
586+ module ->Make <TypeVariableTypeAnnotation>(
587+ struct_instance_parametric_nr));
588+ } else {
589+ struct_instance_parametrics.push_back (struct_instance_parametric_nr);
590+ }
591+ parametric_nds.insert (parent_binding->name_def ());
592+ for (const NameRef* original_name_ref : name_refs) {
593+ name_ref_replacements.emplace (
594+ original_name_ref,
595+ module ->Make <NameRef>(original_name_ref->span (),
596+ lambda_struct_nd->identifier (),
597+ lambda_struct_nd));
598+ }
599+ return absl::OkStatus ();
600+ }
601+
602+ void HandleCapture (
603+ Module* module , const NameDef* original_name_def,
604+ std::vector<ParametricBinding*>& struct_parametric_bindings,
605+ std::vector<ExprOrType>& struct_type_parametrics,
606+ std::vector<StructMemberNode*>& struct_members,
607+ std::vector<std::pair<std::string, Expr*>>& struct_instance_members,
608+ absl::flat_hash_set<const NameDef*>& seen) {
609+ // Create parametric binding with generic type to use for the context
610+ // variable type.
611+ GenericTypeAnnotation* gta =
612+ module ->Make <GenericTypeAnnotation>(original_name_def->span ());
613+ NameDef* generic_name_def =
614+ module ->Make <NameDef>(original_name_def->span (),
615+ absl::Substitute (" parametric_type_for_$0" ,
616+ original_name_def->identifier ()),
617+ /* definer=*/ gta);
618+ NameRef* generic_name_ref =
619+ module ->Make <NameRef>(original_name_def->span (),
620+ generic_name_def->identifier (), generic_name_def);
621+ struct_type_parametrics.push_back (generic_name_ref);
622+ struct_parametric_bindings.push_back (module ->Make <ParametricBinding>(
623+ generic_name_def, gta, /* expr=*/ nullptr ));
624+
625+ NameDef* struct_member_nd = module ->Make <NameDef>(
626+ original_name_def->span (), original_name_def->identifier (),
627+ /* definer=*/ nullptr );
628+ TypeVariableTypeAnnotation* tvta =
629+ module ->Make <TypeVariableTypeAnnotation>(generic_name_ref,
630+ /* internal=*/ true );
631+ StructMemberNode* struct_member = module ->Make <StructMemberNode>(
632+ Span::None (), struct_member_nd, Span::None (), tvta);
633+ struct_members.push_back (struct_member);
634+
635+ // Make a name ref that points to the original name def. Add as a member
636+ // to a new struct instance.
637+ NameRef* struct_instance_nr = module ->Make <NameRef>(
638+ original_name_def->span (), original_name_def->identifier (),
639+ original_name_def);
640+ struct_instance_members.push_back (
641+ std::make_pair (original_name_def->identifier (), struct_instance_nr));
642+ seen.insert (original_name_def);
643+ }
644+
614645 const FileTable& file_table_;
615646};
616647
0 commit comments