diff --git a/compiler/fory_compiler/generators/rust.py b/compiler/fory_compiler/generators/rust.py index 6d722bde39..652a9cd6eb 100644 --- a/compiler/fory_compiler/generators/rust.py +++ b/compiler/fory_compiler/generators/rust.py @@ -1039,12 +1039,18 @@ def generate_type( elif isinstance(field_type, ListType): effective_element_optional = element_optional or field_type.element_optional effective_element_ref = element_ref or field_type.element_ref + element_pointer_type = pointer_type + if field_type.element_ref: + element_pointer_type = self.get_pointer_type( + field_type.element_ref_options, + field_type.element_ref_options.get("weak_ref") is True, + ) element_type = self.generate_type( field_type.element_type, nullable=effective_element_optional, ref=effective_element_ref, parent_stack=parent_stack, - pointer_type=pointer_type, + pointer_type=element_pointer_type, ) list_type = f"::std::vec::Vec<{element_type}>" if ref: @@ -1076,12 +1082,18 @@ def generate_type( parent_stack=parent_stack, pointer_type=pointer_type, ) + value_pointer_type = pointer_type + if field_type.value_ref: + value_pointer_type = self.get_pointer_type( + field_type.value_ref_options, + field_type.value_ref_options.get("weak_ref") is True, + ) value_type = self.generate_type( field_type.value_type, nullable=False, ref=field_type.value_ref, parent_stack=parent_stack, - pointer_type=pointer_type, + pointer_type=value_pointer_type, ) map_type = f"::std::collections::HashMap<{key_type}, {value_type}>" if ref: diff --git a/compiler/fory_compiler/tests/test_generated_code.py b/compiler/fory_compiler/tests/test_generated_code.py index 76f2739b55..62f6350f4f 100644 --- a/compiler/fory_compiler/tests/test_generated_code.py +++ b/compiler/fory_compiler/tests/test_generated_code.py @@ -205,6 +205,43 @@ def test_rust_generated_code_can_use_chrono_temporal_types(): assert "::fory::Duration" not in rust_output +def test_rust_nested_container_ref_uses_correct_pointer_type(): + schema = parse_fdl( + dedent( + """ + package gen; + + message Node { + string value = 1; + } + + message Request { + list> groups = 1; + map> nodes = 2; + } + """ + ) + ) + + rust_output = render_files(generate_files(schema, RustGenerator)) + + assert ( + "pub groups: ::std::vec::Vec<::std::vec::Vec<::std::sync::Arc>>," + in rust_output + ) + assert "::std::vec::Vec<::std::vec::Vec<::std::rc::Rc>>" not in rust_output + assert ( + "pub nodes: ::std::collections::HashMap<::std::string::String, " + "::std::collections::HashMap<::std::string::String, ::std::sync::Arc>>," + in rust_output + ) + assert ( + "::std::collections::HashMap<::std::string::String, " + "::std::collections::HashMap<::std::string::String, ::std::rc::Rc>>" + not in rust_output + ) + + def test_generated_code_integer_encoding_variants_equivalent(): fdl = dedent( """