Skip to content

Commit b6d15bd

Browse files
erinzmoorecopybara-github
authored andcommitted
Refactor HandleLambda in semantics_analysis.
* Factor out some helpers to add parametrics and context capture variables to the lambda. * Update the struct type annotation to include the parametrics for captured variables. PiperOrigin-RevId: 919245538
1 parent cac9fe8 commit b6d15bd

7 files changed

Lines changed: 141 additions & 110 deletions

xls/dslx/frontend/semantics_analysis.cc

Lines changed: 109 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -383,18 +383,21 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
383383
Span span = node->span();
384384
CollectNameRefs collect_nr;
385385
XLS_RETURN_IF_ERROR(node->Accept(&collect_nr));
386-
absl::flat_hash_set<const NameDef*> seen;
387386

388-
// If there are any parametric bindings in the containing function that are
389-
// referenced in the lambda, they should be added as parametric bindings to
390-
// the `StructDef`.
391-
std::vector<ParametricBinding*> struct_parametrics;
392-
std::optional<const Function*> containing_fn = GetContainingFunction(node);
387+
// Parametric bindings for the struct definition.
388+
std::vector<ParametricBinding*> struct_parametric_bindings;
389+
// Parametrics in the struct type annotation.
393390
std::vector<ExprOrType> struct_type_parametrics;
391+
// Parametric values for the struct instantiation.
394392
std::vector<ExprOrType> struct_instance_parametrics;
393+
// NameDefs that have been added to the struct parametric bindings.
395394
absl::flat_hash_set<const NameDef*> parametric_nds;
396395
absl::flat_hash_map<const NameRef*, NameRef*> name_ref_replacements;
397396

397+
// If there are any parametric bindings in the containing function that are
398+
// referenced in the lambda, they should be added as parametric bindings to
399+
// the `StructDef`.
400+
std::optional<const Function*> containing_fn = GetContainingFunction(node);
398401
if (containing_fn.has_value()) {
399402
for (ParametricBinding* parent_binding :
400403
(*containing_fn)->parametric_bindings()) {
@@ -403,44 +406,10 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
403406
if (name_refs.empty()) {
404407
continue;
405408
}
406-
NameDef* lambda_struct_nd = module->Make<NameDef>(
407-
parent_binding->span(), parent_binding->identifier() + "_ls",
408-
/*definer=*/nullptr);
409-
XLS_ASSIGN_OR_RETURN(AstNode * cloned_ta,
410-
CloneAst(parent_binding->type_annotation()));
411-
412-
AstNode* cloned_expr = nullptr;
413-
if (parent_binding->expr() != nullptr) {
414-
XLS_ASSIGN_OR_RETURN(cloned_expr, CloneAst(parent_binding->expr()));
415-
}
416-
ParametricBinding* lambda_struct_binding =
417-
module->Make<ParametricBinding>(
418-
lambda_struct_nd, absl::down_cast<TypeAnnotation*>(cloned_ta),
419-
absl::down_cast<Expr*>(cloned_expr));
420-
struct_parametrics.push_back(lambda_struct_binding);
421-
NameRef* struct_type_parametric_nr = module->Make<NameRef>(
422-
parent_binding->span(), lambda_struct_nd->identifier(),
423-
lambda_struct_nd);
424-
NameRef* struct_instance_parametric_nr = module->Make<NameRef>(
425-
parent_binding->span(), parent_binding->identifier(),
426-
parent_binding->name_def());
427-
struct_type_parametrics.push_back(struct_type_parametric_nr);
428-
if (parent_binding->type_annotation()
429-
->IsAnnotation<GenericTypeAnnotation>()) {
430-
struct_instance_parametrics.push_back(
431-
module->Make<TypeVariableTypeAnnotation>(
432-
struct_instance_parametric_nr));
433-
} else {
434-
struct_instance_parametrics.push_back(struct_instance_parametric_nr);
435-
}
436-
parametric_nds.insert(parent_binding->name_def());
437-
for (const NameRef* original_name_ref : name_refs) {
438-
name_ref_replacements.emplace(
439-
original_name_ref,
440-
module->Make<NameRef>(original_name_ref->span(),
441-
lambda_struct_nd->identifier(),
442-
lambda_struct_nd));
443-
}
409+
XLS_RETURN_IF_ERROR(HandleParentParametric(
410+
module, parent_binding, name_refs, struct_parametric_bindings,
411+
struct_type_parametrics, struct_instance_parametrics,
412+
parametric_nds, name_ref_replacements));
444413
}
445414
}
446415

@@ -449,42 +418,13 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
449418
// already added as a parametric binding.
450419
std::vector<StructMemberNode*> struct_members;
451420
std::vector<std::pair<std::string, Expr*>> struct_instance_members;
421+
absl::flat_hash_set<const NameDef*> seen;
452422
for (const NameDef* original_name_def :
453423
collect_nr.NameDefsDefinedPrior(span.start())) {
454424
if (!parametric_nds.contains(original_name_def)) {
455-
// Create parametric binding with generic type to use for the context
456-
// variable type.
457-
GenericTypeAnnotation* gta =
458-
module->Make<GenericTypeAnnotation>(original_name_def->span());
459-
NameDef* generic_name_def = module->Make<NameDef>(
460-
original_name_def->span(),
461-
absl::Substitute("parametric_type_for_$0",
462-
original_name_def->identifier()),
463-
/*definer=*/gta);
464-
NameRef* generic_name_ref = module->Make<NameRef>(
465-
original_name_def->span(), generic_name_def->identifier(),
466-
generic_name_def);
467-
struct_parametrics.push_back(module->Make<ParametricBinding>(
468-
generic_name_def, gta, /*expr=*/nullptr));
469-
470-
NameDef* struct_member_nd = module->Make<NameDef>(
471-
original_name_def->span(), original_name_def->identifier(),
472-
/*definer=*/nullptr);
473-
TypeVariableTypeAnnotation* tvta =
474-
module->Make<TypeVariableTypeAnnotation>(generic_name_ref,
475-
/*internal=*/true);
476-
StructMemberNode* struct_member = module->Make<StructMemberNode>(
477-
Span::None(), struct_member_nd, Span::None(), tvta);
478-
struct_members.push_back(struct_member);
479-
480-
// Make a name ref that points to the original name def. Add as a member
481-
// to a new struct instance.
482-
NameRef* struct_instance_nr = module->Make<NameRef>(
483-
original_name_def->span(), original_name_def->identifier(),
484-
original_name_def);
485-
struct_instance_members.push_back(std::make_pair(
486-
original_name_def->identifier(), struct_instance_nr));
487-
seen.insert(original_name_def);
425+
HandleCapture(module, original_name_def, struct_parametric_bindings,
426+
struct_type_parametrics, struct_members,
427+
struct_instance_members, seen);
488428
}
489429
}
490430

@@ -494,7 +434,7 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
494434
span.ToString(file_table_)),
495435
/*definer=*/nullptr);
496436
StructDef* full_struct_def =
497-
module->Make<StructDef>(span, struct_nd, struct_parametrics,
437+
module->Make<StructDef>(span, struct_nd, struct_parametric_bindings,
498438
struct_members, /*is_public=*/false);
499439
TypeRefTypeAnnotation* struct_type_annotation =
500440
module->Make<TypeRefTypeAnnotation>(
@@ -611,6 +551,97 @@ class ReplaceLambdaWithInvocation : public AstNodeVisitorWithDefault {
611551
}
612552

613553
private:
554+
absl::Status HandleParentParametric(
555+
Module* module, const ParametricBinding* parent_binding,
556+
absl::flat_hash_set<const NameRef*> name_refs,
557+
std::vector<ParametricBinding*>& struct_parametric_bindings,
558+
std::vector<ExprOrType>& struct_type_parametrics,
559+
std::vector<ExprOrType>& struct_instance_parametrics,
560+
absl::flat_hash_set<const NameDef*>& parametric_nds,
561+
absl::flat_hash_map<const NameRef*, NameRef*>& name_ref_replacements) {
562+
NameDef* lambda_struct_nd = module->Make<NameDef>(
563+
parent_binding->span(), parent_binding->identifier() + "_ls",
564+
/*definer=*/nullptr);
565+
XLS_ASSIGN_OR_RETURN(AstNode * cloned_ta,
566+
CloneAst(parent_binding->type_annotation()));
567+
568+
AstNode* cloned_expr = nullptr;
569+
if (parent_binding->expr() != nullptr) {
570+
XLS_ASSIGN_OR_RETURN(cloned_expr, CloneAst(parent_binding->expr()));
571+
}
572+
ParametricBinding* lambda_struct_binding = module->Make<ParametricBinding>(
573+
lambda_struct_nd, absl::down_cast<TypeAnnotation*>(cloned_ta),
574+
absl::down_cast<Expr*>(cloned_expr));
575+
struct_parametric_bindings.push_back(lambda_struct_binding);
576+
NameRef* struct_type_parametric_nr =
577+
module->Make<NameRef>(parent_binding->span(),
578+
lambda_struct_nd->identifier(), lambda_struct_nd);
579+
NameRef* struct_instance_parametric_nr = module->Make<NameRef>(
580+
parent_binding->span(), parent_binding->identifier(),
581+
parent_binding->name_def());
582+
struct_type_parametrics.push_back(struct_type_parametric_nr);
583+
if (parent_binding->type_annotation()
584+
->IsAnnotation<GenericTypeAnnotation>()) {
585+
struct_instance_parametrics.push_back(
586+
module->Make<TypeVariableTypeAnnotation>(
587+
struct_instance_parametric_nr));
588+
} else {
589+
struct_instance_parametrics.push_back(struct_instance_parametric_nr);
590+
}
591+
parametric_nds.insert(parent_binding->name_def());
592+
for (const NameRef* original_name_ref : name_refs) {
593+
name_ref_replacements.emplace(
594+
original_name_ref,
595+
module->Make<NameRef>(original_name_ref->span(),
596+
lambda_struct_nd->identifier(),
597+
lambda_struct_nd));
598+
}
599+
return absl::OkStatus();
600+
}
601+
602+
void HandleCapture(
603+
Module* module, const NameDef* original_name_def,
604+
std::vector<ParametricBinding*>& struct_parametric_bindings,
605+
std::vector<ExprOrType>& struct_type_parametrics,
606+
std::vector<StructMemberNode*>& struct_members,
607+
std::vector<std::pair<std::string, Expr*>>& struct_instance_members,
608+
absl::flat_hash_set<const NameDef*>& seen) {
609+
// Create parametric binding with generic type to use for the context
610+
// variable type.
611+
GenericTypeAnnotation* gta =
612+
module->Make<GenericTypeAnnotation>(original_name_def->span());
613+
NameDef* generic_name_def =
614+
module->Make<NameDef>(original_name_def->span(),
615+
absl::Substitute("parametric_type_for_$0",
616+
original_name_def->identifier()),
617+
/*definer=*/gta);
618+
NameRef* generic_name_ref =
619+
module->Make<NameRef>(original_name_def->span(),
620+
generic_name_def->identifier(), generic_name_def);
621+
struct_type_parametrics.push_back(generic_name_ref);
622+
struct_parametric_bindings.push_back(module->Make<ParametricBinding>(
623+
generic_name_def, gta, /*expr=*/nullptr));
624+
625+
NameDef* struct_member_nd = module->Make<NameDef>(
626+
original_name_def->span(), original_name_def->identifier(),
627+
/*definer=*/nullptr);
628+
TypeVariableTypeAnnotation* tvta =
629+
module->Make<TypeVariableTypeAnnotation>(generic_name_ref,
630+
/*internal=*/true);
631+
StructMemberNode* struct_member = module->Make<StructMemberNode>(
632+
Span::None(), struct_member_nd, Span::None(), tvta);
633+
struct_members.push_back(struct_member);
634+
635+
// Make a name ref that points to the original name def. Add as a member
636+
// to a new struct instance.
637+
NameRef* struct_instance_nr = module->Make<NameRef>(
638+
original_name_def->span(), original_name_def->identifier(),
639+
original_name_def);
640+
struct_instance_members.push_back(
641+
std::make_pair(original_name_def->identifier(), struct_instance_nr));
642+
seen.insert(original_name_def);
643+
}
644+
614645
const FileTable& file_table_;
615646
};
616647

xls/dslx/ir_convert/testdata/ir_converter_test_MapLambda.ir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,22 @@ package test_module
22

33
file_number 0 "test_module.x"
44

5-
fn __test_module__lambda_capture_struct_at_test_module_x_4_12_4_34__call__u32_u32(self: (bits[32]) id=1, i: bits[32] id=2) -> bits[32] {
5+
fn __test_module__lambda_capture_struct_at_test_module_x_4_12_4_34_parametric_type_for_x___call__u32_u32(self: (bits[32]) id=1, i: bits[32] id=2) -> bits[32] {
66
self_x: bits[32] = tuple_index(self, index=0, id=3)
77
ret add.4: bits[32] = add(i, self_x, id=4)
88
}
99

10-
fn __SYNTHETIC_LOOP_BODY___test_module__add_five_CALLING___test_module__lambda_capture_struct_at_test_module_x_4_12_4_34__call__u32_u32__MAP_0(idx: bits[32] id=8, array_val: bits[32][4] id=9, context_struct: (bits[32]) id=10, input_array: bits[32][4] id=11) -> bits[32][4] {
10+
fn __SYNTHETIC_LOOP_BODY___test_module__add_five_CALLING___test_module__lambda_capture_struct_at_test_module_x_4_12_4_34_parametric_type_for_x___call__u32_u32__MAP_0(idx: bits[32] id=8, array_val: bits[32][4] id=9, context_struct: (bits[32]) id=10, input_array: bits[32][4] id=11) -> bits[32][4] {
1111
array_index.12: bits[32] = array_index(input_array, indices=[idx], id=12)
12-
map_call_call: bits[32] = invoke(context_struct, array_index.12, to_apply=__test_module__lambda_capture_struct_at_test_module_x_4_12_4_34__call__u32_u32, id=13)
12+
map_call_call: bits[32] = invoke(context_struct, array_index.12, to_apply=__test_module__lambda_capture_struct_at_test_module_x_4_12_4_34_parametric_type_for_x___call__u32_u32, id=13)
1313
ret map_iteration: bits[32][4] = array_update(array_val, map_call_call, indices=[idx], id=14)
1414
}
1515

1616
fn __test_module__add_five(arr: bits[32][4] id=5) -> bits[32][4] {
1717
x: bits[32] = literal(value=5, id=6)
1818
literal.15: bits[32][4] = literal(value=[0, 0, 0, 0], id=15)
1919
tuple.7: (bits[32]) = tuple(x, id=7)
20-
ret map_invocation__call: bits[32][4] = counted_for(literal.15, trip_count=4, stride=1, body=__SYNTHETIC_LOOP_BODY___test_module__add_five_CALLING___test_module__lambda_capture_struct_at_test_module_x_4_12_4_34__call__u32_u32__MAP_0, invariant_args=[tuple.7, arr], id=16)
20+
ret map_invocation__call: bits[32][4] = counted_for(literal.15, trip_count=4, stride=1, body=__SYNTHETIC_LOOP_BODY___test_module__add_five_CALLING___test_module__lambda_capture_struct_at_test_module_x_4_12_4_34_parametric_type_for_x___call__u32_u32__MAP_0, invariant_args=[tuple.7, arr], id=16)
2121
}
2222

2323
fn __test_module__main() -> bits[32] {

xls/dslx/ir_convert/testdata/ir_converter_test_MapLambdaImplicitReturn.ir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,22 @@ package test_module
22

33
file_number 0 "test_module.x"
44

5-
fn __test_module__lambda_capture_struct_at_test_module_x_4_12_4_27__call__u32_u32_u32(self: (bits[32]) id=1, i: bits[32] id=2) -> bits[32] {
5+
fn __test_module__lambda_capture_struct_at_test_module_x_4_12_4_27_parametric_type_for_x___call__u32_u32_u32(self: (bits[32]) id=1, i: bits[32] id=2) -> bits[32] {
66
self_x: bits[32] = tuple_index(self, index=0, id=3)
77
ret add.4: bits[32] = add(i, self_x, id=4)
88
}
99

10-
fn __SYNTHETIC_LOOP_BODY___test_module__add_five_CALLING___test_module__lambda_capture_struct_at_test_module_x_4_12_4_27__call__u32_u32_u32__MAP_0(idx: bits[32] id=8, array_val: bits[32][4] id=9, context_struct: (bits[32]) id=10, input_array: bits[32][4] id=11) -> bits[32][4] {
10+
fn __SYNTHETIC_LOOP_BODY___test_module__add_five_CALLING___test_module__lambda_capture_struct_at_test_module_x_4_12_4_27_parametric_type_for_x___call__u32_u32_u32__MAP_0(idx: bits[32] id=8, array_val: bits[32][4] id=9, context_struct: (bits[32]) id=10, input_array: bits[32][4] id=11) -> bits[32][4] {
1111
array_index.12: bits[32] = array_index(input_array, indices=[idx], id=12)
12-
map_call_call: bits[32] = invoke(context_struct, array_index.12, to_apply=__test_module__lambda_capture_struct_at_test_module_x_4_12_4_27__call__u32_u32_u32, id=13)
12+
map_call_call: bits[32] = invoke(context_struct, array_index.12, to_apply=__test_module__lambda_capture_struct_at_test_module_x_4_12_4_27_parametric_type_for_x___call__u32_u32_u32, id=13)
1313
ret map_iteration: bits[32][4] = array_update(array_val, map_call_call, indices=[idx], id=14)
1414
}
1515

1616
fn __test_module__add_five(arr: bits[32][4] id=5) -> bits[32][4] {
1717
x: bits[32] = literal(value=5, id=6)
1818
literal.15: bits[32][4] = literal(value=[0, 0, 0, 0], id=15)
1919
tuple.7: (bits[32]) = tuple(x, id=7)
20-
ret map_invocation__call: bits[32][4] = counted_for(literal.15, trip_count=4, stride=1, body=__SYNTHETIC_LOOP_BODY___test_module__add_five_CALLING___test_module__lambda_capture_struct_at_test_module_x_4_12_4_27__call__u32_u32_u32__MAP_0, invariant_args=[tuple.7, arr], id=16)
20+
ret map_invocation__call: bits[32][4] = counted_for(literal.15, trip_count=4, stride=1, body=__SYNTHETIC_LOOP_BODY___test_module__add_five_CALLING___test_module__lambda_capture_struct_at_test_module_x_4_12_4_27_parametric_type_for_x___call__u32_u32_u32__MAP_0, invariant_args=[tuple.7, arr], id=16)
2121
}
2222

2323
fn __test_module__main() -> bits[32] {

xls/dslx/ir_convert/testdata/ir_converter_test_MapLambdaWithParentParametric.ir

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,15 @@ package test_module
22

33
file_number 0 "test_module.x"
44

5-
fn __test_module__lambda_capture_struct_at_test_module_x_4_12_4_36_N_ls___call__16_u32_u16_u16(self: (bits[16]) id=1, i: bits[32] id=2) -> bits[16] {
5+
fn __test_module__lambda_capture_struct_at_test_module_x_4_12_4_36_N_ls__parametric_type_for_x___call__16_u32_u16_u16(self: (bits[16]) id=1, i: bits[32] id=2) -> bits[16] {
66
bit_slice.3: bits[16] = bit_slice(i, start=0, width=16, id=3)
77
self_x: bits[16] = tuple_index(self, index=0, id=4)
88
ret add.5: bits[16] = add(bit_slice.3, self_x, id=5)
99
}
1010

11-
fn __SYNTHETIC_LOOP_BODY___test_module__add_five__16_CALLING___test_module__lambda_capture_struct_at_test_module_x_4_12_4_36_N_ls___call__16_u32_u16_u16__MAP_0(idx: bits[32] id=10, array_val: bits[16][4] id=11, context_struct: (bits[16]) id=12, input_array: bits[32][4] id=13) -> bits[16][4] {
11+
fn __SYNTHETIC_LOOP_BODY___test_module__add_five__16_CALLING___test_module__lambda_capture_struct_at_test_module_x_4_12_4_36_N_ls__parametric_type_for_x___call__16_u32_u16_u16__MAP_0(idx: bits[32] id=10, array_val: bits[16][4] id=11, context_struct: (bits[16]) id=12, input_array: bits[32][4] id=13) -> bits[16][4] {
1212
array_index.14: bits[32] = array_index(input_array, indices=[idx], id=14)
13-
map_call_call: bits[16] = invoke(context_struct, array_index.14, to_apply=__test_module__lambda_capture_struct_at_test_module_x_4_12_4_36_N_ls___call__16_u32_u16_u16, id=15)
13+
map_call_call: bits[16] = invoke(context_struct, array_index.14, to_apply=__test_module__lambda_capture_struct_at_test_module_x_4_12_4_36_N_ls__parametric_type_for_x___call__16_u32_u16_u16, id=15)
1414
ret map_iteration: bits[16][4] = array_update(array_val, map_call_call, indices=[idx], id=16)
1515
}
1616

@@ -19,7 +19,7 @@ fn __test_module__add_five__16(arr: bits[32][4] id=6) -> bits[16][4] {
1919
literal.17: bits[16][4] = literal(value=[0, 0, 0, 0], id=17)
2020
tuple.9: (bits[16]) = tuple(x, id=9)
2121
N: bits[32] = literal(value=16, id=7)
22-
ret map_invocation__call: bits[16][4] = counted_for(literal.17, trip_count=4, stride=1, body=__SYNTHETIC_LOOP_BODY___test_module__add_five__16_CALLING___test_module__lambda_capture_struct_at_test_module_x_4_12_4_36_N_ls___call__16_u32_u16_u16__MAP_0, invariant_args=[tuple.9, arr], id=18)
22+
ret map_invocation__call: bits[16][4] = counted_for(literal.17, trip_count=4, stride=1, body=__SYNTHETIC_LOOP_BODY___test_module__add_five__16_CALLING___test_module__lambda_capture_struct_at_test_module_x_4_12_4_36_N_ls__parametric_type_for_x___call__16_u32_u16_u16__MAP_0, invariant_args=[tuple.9, arr], id=18)
2323
}
2424

2525
fn __test_module__main() -> bits[16] {

0 commit comments

Comments
 (0)