Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 14 additions & 2 deletions compiler/fory_compiler/generators/rust.py
Original file line number Diff line number Diff line change
Expand Up @@ -1039,12 +1039,18 @@ def generate_type(
elif isinstance(field_type, ListType):
effective_element_optional = element_optional or field_type.element_optional
effective_element_ref = element_ref or field_type.element_ref
element_pointer_type = pointer_type
if field_type.element_ref:
element_pointer_type = self.get_pointer_type(
field_type.element_ref_options,
field_type.element_ref_options.get("weak_ref") is True,
)
element_type = self.generate_type(
field_type.element_type,
nullable=effective_element_optional,
ref=effective_element_ref,
parent_stack=parent_stack,
pointer_type=pointer_type,
pointer_type=element_pointer_type,
)
list_type = f"::std::vec::Vec<{element_type}>"
if ref:
Expand Down Expand Up @@ -1076,12 +1082,18 @@ def generate_type(
parent_stack=parent_stack,
pointer_type=pointer_type,
)
value_pointer_type = pointer_type
if field_type.value_ref:
value_pointer_type = self.get_pointer_type(
field_type.value_ref_options,
field_type.value_ref_options.get("weak_ref") is True,
)
value_type = self.generate_type(
field_type.value_type,
nullable=False,
ref=field_type.value_ref,
parent_stack=parent_stack,
pointer_type=pointer_type,
pointer_type=value_pointer_type,
)
map_type = f"::std::collections::HashMap<{key_type}, {value_type}>"
if ref:
Expand Down
37 changes: 37 additions & 0 deletions compiler/fory_compiler/tests/test_generated_code.py
Original file line number Diff line number Diff line change
Expand Up @@ -205,6 +205,43 @@ def test_rust_generated_code_can_use_chrono_temporal_types():
assert "::fory::Duration" not in rust_output


def test_rust_nested_container_ref_uses_correct_pointer_type():
schema = parse_fdl(
dedent(
"""
package gen;

message Node {
string value = 1;
}

message Request {
list<list<ref(thread_safe=true) Node>> groups = 1;
map<string, map<string, ref(thread_safe=true) Node>> nodes = 2;
}
"""
)
)

rust_output = render_files(generate_files(schema, RustGenerator))

assert (
"pub groups: ::std::vec::Vec<::std::vec::Vec<::std::sync::Arc<Node>>>,"
in rust_output
)
assert "::std::vec::Vec<::std::vec::Vec<::std::rc::Rc<Node>>>" not in rust_output
assert (
"pub nodes: ::std::collections::HashMap<::std::string::String, "
"::std::collections::HashMap<::std::string::String, ::std::sync::Arc<Node>>>,"
in rust_output
)
assert (
"::std::collections::HashMap<::std::string::String, "
"::std::collections::HashMap<::std::string::String, ::std::rc::Rc<Node>>>"
not in rust_output
)


def test_generated_code_integer_encoding_variants_equivalent():
fdl = dedent(
"""
Expand Down
Loading