Skip to content

Commit 850ff43

Browse files
committed
fix(compiler): preserve proto package info in IR type lookup and validation
Signed-off-by: Peiyang He <peiyang_he@smail.nju.edu.cn>
1 parent f5d72ab commit 850ff43

File tree

4 files changed

+216
-33
lines changed

4 files changed

+216
-33
lines changed

compiler/fory_compiler/frontend/proto/translator.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,7 @@ def _translate_enum(self, proto_enum: ProtoEnum) -> Enum:
143143
line=proto_enum.line,
144144
column=proto_enum.column,
145145
location=self._location(proto_enum.line, proto_enum.column),
146+
package=self.proto_schema.package,
146147
)
147148

148149
def _translate_message(self, proto_msg: ProtoMessage) -> Message:
@@ -173,6 +174,7 @@ def _translate_message(self, proto_msg: ProtoMessage) -> Message:
173174
line=proto_msg.line,
174175
column=proto_msg.column,
175176
location=self._location(proto_msg.line, proto_msg.column),
177+
package=self.proto_schema.package,
176178
)
177179

178180
def _translate_field(self, proto_field: ProtoField) -> Field:
@@ -261,6 +263,7 @@ def _translate_oneof(
261263
line=oneof.line,
262264
column=oneof.column,
263265
location=self._location(oneof.line, oneof.column),
266+
package=self.proto_schema.package,
264267
)
265268

266269
def _translate_oneof_case(self, proto_field: ProtoField) -> Field:

compiler/fory_compiler/ir/ast.py

Lines changed: 36 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,7 @@ class Message:
178178
location: Optional[SourceLocation] = None
179179
id_generated: bool = False
180180
id_source: Optional[str] = None
181+
package: Optional[str] = None # None if not protobuf
181182

182183
def __repr__(self) -> str:
183184
id_str = f" [id={self.type_id}]" if self.type_id is not None else ""
@@ -218,6 +219,7 @@ class Enum:
218219
location: Optional[SourceLocation] = None
219220
id_generated: bool = False
220221
id_source: Optional[str] = None
222+
package: Optional[str] = None # None if not protobuf
221223

222224
def __repr__(self) -> str:
223225
id_str = f" [id={self.type_id}]" if self.type_id is not None else ""
@@ -238,6 +240,7 @@ class Union:
238240
location: Optional[SourceLocation] = None
239241
id_generated: bool = False
240242
id_source: Optional[str] = None
243+
package: Optional[str] = None # None if not protobuf
241244

242245
def __repr__(self) -> str:
243246
id_str = f" [id={self.type_id}]" if self.type_id is not None else ""
@@ -317,16 +320,26 @@ def get_option(self, name: str, default: Optional[str] = None) -> Optional[str]:
317320
return self.options.get(name, default)
318321

319322
def get_type(self, name: str) -> Optional[TypingUnion[Message, Enum, "Union"]]:
320-
"""Look up a type by name, supporting qualified names like Parent.Child."""
321-
# Handle qualified names (e.g., SearchResponse.Result)
322-
if "." in name:
323-
parts = name.split(".")
324-
# Find the top-level type
323+
"""Look up a type by name, supporting qualified names like Parent.Child,
324+
package.TypeName, and .package.TypeName (absolute proto-style)."""
325+
cleaned = name.lstrip(".")
326+
if "." in cleaned:
327+
parts = cleaned.split(".")
328+
# Try the first component as a top-level type name (e.g. "A" in "A.X").
325329
current = self._get_top_level_type(parts[0])
330+
if current is None:
331+
# The first component is not a top-level type, so treat it as a package prefix
332+
# (e.g. "demo" in "demo.Foo" or "demo" in "demo.A.X").
333+
if len(parts) >= 2:
334+
current = self._get_top_level_type_in_package(parts[1], parts[0])
335+
remaining = parts[2:]
336+
else:
337+
return None
338+
else:
339+
remaining = parts[1:]
326340
if current is None:
327341
return None
328-
# Navigate through nested types
329-
for part in parts[1:]:
342+
for part in remaining:
330343
if isinstance(current, Message):
331344
current = current.get_nested_type(part)
332345
if current is None:
@@ -336,7 +349,7 @@ def get_type(self, name: str) -> Optional[TypingUnion[Message, Enum, "Union"]]:
336349
return None
337350
return current
338351
else:
339-
return self._get_top_level_type(name)
352+
return self._get_top_level_type(cleaned)
340353

341354
def _get_top_level_type(
342355
self, name: str
@@ -353,6 +366,21 @@ def _get_top_level_type(
353366
return message
354367
return None
355368

369+
def _get_top_level_type_in_package(
    self, name: str, package: str
) -> Optional[TypingUnion[Message, Enum, "Union"]]:
    """Look up a top-level type by simple name within a specific package.

    Args:
        name: Simple (unqualified) type name, e.g. ``"Address"``.
        package: Package the type must be declared in. It is compared to
            each definition's ``package`` attribute with ``==``, so it has
            to be the exact package string (no leading dot).

    Returns:
        The first matching definition, or ``None`` when no top-level type
        has both this ``name`` and this ``package``. A definition whose
        ``package`` is ``None`` therefore never matches.
    """
    # Enums are searched first, then unions, then messages, so that when
    # several definitions match, the winner is predictable.
    for candidates in (self.enums, self.unions, self.messages):
        for type_def in candidates:
            if type_def.name == name and type_def.package == package:
                return type_def
    return None
383+
356384
def get_all_types(self) -> List[TypingUnion[Message, Enum, "Union"]]:
357385
"""Get all types including nested types (flattened)."""
358386
result: List[TypingUnion[Message, Enum, "Union"]] = []

compiler/fory_compiler/ir/validator.py

Lines changed: 50 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -86,18 +86,20 @@ def _apply_type_id_defaults(self) -> None:
8686
if t.type_id is not None:
8787
used_ids[t.type_id] = t
8888

89-
def qualify(full_name: str) -> str:
90-
package = self.schema.package_alias or self.schema.package
89+
def qualify(full_name: str, type_package: Optional[str] = None) -> str:
90+
package = type_package or self.schema.package_alias or self.schema.package
9191
if package:
9292
return f"{package}.{full_name}"
9393
return full_name
9494

95-
def resolve_hash_source(full_name: str, alias: Optional[str]) -> str:
95+
def resolve_hash_source(
96+
full_name: str, alias: Optional[str], type_package: Optional[str] = None
97+
) -> str:
9698
if not alias:
97-
return qualify(full_name)
99+
return qualify(full_name, type_package)
98100
if "." in alias:
99101
return alias
100-
package = self.schema.package_alias or self.schema.package
102+
package = type_package or self.schema.package_alias or self.schema.package
101103
if package:
102104
return f"{package}.{alias}"
103105
return alias
@@ -106,7 +108,8 @@ def assign_id(type_def, full_name: str) -> None:
106108
if type_def.type_id is not None:
107109
return
108110
alias = type_def.options.get("alias")
109-
source_name = resolve_hash_source(full_name, alias)
111+
type_package = getattr(type_def, "package", None)
112+
source_name = resolve_hash_source(full_name, alias, type_package)
110113
generated_id = compute_registered_type_id(source_name)
111114
if generated_id in used_ids:
112115
self._error(
@@ -144,28 +147,33 @@ def walk_message(message: Message, parent_path: str = "") -> None:
144147
walk_message(message)
145148

146149
def _check_duplicate_type_names(self) -> None:
    """Report top-level types that reuse a name within the same package.

    Enums, unions and messages share one namespace: they are visited in
    that order and an error is recorded through ``self._error`` for every
    definition whose ``(name, package)`` pair was already seen. The error
    points at the duplicate's own location, falling back to the location
    of the first definition when the duplicate has none.
    """
    # Key on (name, package) rather than on name alone so that same-named
    # types declared in different packages are not flagged. Definitions
    # without a ``package`` attribute are keyed with ``None``.
    first_seen = {}
    for type_defs in (
        self.schema.enums,
        self.schema.unions,
        self.schema.messages,
    ):
        for type_def in type_defs:
            key = (type_def.name, getattr(type_def, "package", None))
            if key in first_seen:
                self._error(
                    f"Duplicate type name: {type_def.name}",
                    type_def.location or first_seen[key],
                )
            # setdefault keeps the location of the first definition.
            first_seen.setdefault(key, type_def.location)
169177

170178
def _check_duplicate_type_ids(self) -> None:
171179
type_ids = {}
@@ -309,15 +317,12 @@ def _is_message_type(
309317
def _resolve_named_type(
310318
self, name: str, parent_stack: List[Message]
311319
) -> Optional[TypingUnion[Message, Enum, Union]]:
312-
parts = name.split(".")
320+
cleaned = name.lstrip(".")
321+
parts = cleaned.split(".")
313322
if len(parts) > 1:
314-
current = self._find_top_level_type(parts[0])
315-
for part in parts[1:]:
316-
if isinstance(current, Message):
317-
current = current.get_nested_type(part)
318-
else:
319-
return None
320-
return current
323+
# Call schema.get_type to handle both nested-type paths
324+
# (e.g. A.X) and package-prefixed paths (e.g. demo.Foo, demo.A.X).
325+
return self.schema.get_type(name)
321326
for msg in reversed(parent_stack):
322327
nested = msg.get_nested_type(name)
323328
if nested is not None:
@@ -354,8 +359,28 @@ def check_type_ref(
354359
found = True
355360
break
356361

357-
if not found and self.schema.get_type(type_name) is not None:
358-
found = True
362+
if not found:
363+
resolved = self.schema.get_type(type_name)
364+
if resolved is not None:
365+
found = True
366+
# Enforce the behavior covered by test_proto_imported_package_qualified_types_fail:
367+
# Reject unqualified type names from a different package.
368+
# This check only happens when both type and schema have explicit package info,
369+
# so that non-protobuf schemas whose package is None are unaffected.
370+
if "." not in type_name:
371+
type_package = getattr(resolved, "package", None)
372+
schema_package = self.schema.package
373+
if (
374+
type_package is not None
375+
and schema_package is not None
376+
and type_package != schema_package
377+
):
378+
self._error(
379+
f"Type '{type_name}' belongs to package "
380+
f"'{type_package}'; use '{type_package}.{type_name}' "
381+
f"to reference it from package '{schema_package}'",
382+
field.location,
383+
)
359384

360385
if not found:
361386
self._error(f"Unknown type '{type_name}'", field.location)

compiler/fory_compiler/tests/test_proto_frontend.py

Lines changed: 127 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,10 @@
2020
from fory_compiler.frontend.proto import ProtoFrontend
2121
from fory_compiler.ir.ast import PrimitiveType
2222
from fory_compiler.ir.types import PrimitiveKind
23+
from fory_compiler.cli import resolve_imports
24+
from fory_compiler.ir.validator import SchemaValidator
25+
import tempfile
26+
from pathlib import Path
2327

2428

2529
def test_proto_type_mapping():
@@ -106,3 +110,126 @@ def test_proto_file_option_enable_auto_type_id():
106110
"""
107111
schema = ProtoFrontend().parse(source)
108112
assert schema.get_option("enable_auto_type_id") is False
113+
114+
115+
def test_proto_nested_qualified_types_pass():
116+
source = """
117+
syntax = "proto3";
118+
package demo;
119+
120+
message A {
121+
message X{}
122+
}
123+
message B {
124+
A.X x1 = 1;
125+
.demo.A.X x2 = 2;
126+
demo.A.X x3 = 3;
127+
}
128+
"""
129+
schema = ProtoFrontend().parse(source)
130+
validator = SchemaValidator(schema)
131+
assert validator.validate()
132+
133+
134+
def test_proto_nested_qualified_types_fail():
135+
source = """
136+
syntax = "proto3";
137+
package demo;
138+
139+
message A {
140+
message X{}
141+
}
142+
message B {
143+
X x = 1;
144+
}
145+
"""
146+
schema = ProtoFrontend().parse(source)
147+
validator = SchemaValidator(schema)
148+
assert not validator.validate()
149+
150+
151+
def test_proto_same_package_qualified_types_pass():
152+
source = """
153+
syntax = "proto3";
154+
package demo;
155+
156+
message Foo {}
157+
158+
message Bar {
159+
demo.Foo foo1 = 1;
160+
.demo.Foo foo2 = 2;
161+
}
162+
"""
163+
schema = ProtoFrontend().parse(source)
164+
validator = SchemaValidator(schema)
165+
assert validator.validate()
166+
167+
168+
def test_proto_imported_package_qualified_types_fail():
169+
with tempfile.TemporaryDirectory() as tmpdir:
170+
tmpdir = Path(tmpdir)
171+
common_proto = tmpdir / "common.proto"
172+
common_proto.write_text(
173+
"""
174+
syntax = "proto3";
175+
package common;
176+
177+
message Address {}
178+
"""
179+
)
180+
main_proto = tmpdir / "main.proto"
181+
main_proto.write_text(
182+
"""
183+
syntax = "proto3";
184+
package main;
185+
import "common.proto";
186+
187+
message User {
188+
Address addr1 = 1;
189+
Address addr2 = 2;
190+
}
191+
"""
192+
)
193+
schema = resolve_imports(main_proto, [tmpdir])
194+
validator = SchemaValidator(schema)
195+
assert not validator.validate()
196+
197+
198+
def test_proto_imported_package_qualified_types_pass():
199+
with tempfile.TemporaryDirectory() as tmpdir:
200+
tmpdir = Path(tmpdir)
201+
common1_proto = tmpdir / "common1.proto"
202+
common1_proto.write_text(
203+
"""
204+
syntax = "proto3";
205+
package common1;
206+
207+
message Address {}
208+
"""
209+
)
210+
common2_proto = tmpdir / "common2.proto"
211+
common2_proto.write_text(
212+
"""
213+
syntax = "proto3";
214+
package common2;
215+
216+
message Address {}
217+
"""
218+
)
219+
main_proto = tmpdir / "main.proto"
220+
main_proto.write_text(
221+
"""
222+
syntax = "proto3";
223+
package main;
224+
import "common1.proto";
225+
import "common2.proto";
226+
227+
message User {
228+
common1.Address addr1 = 1;
229+
.common2.Address addr2 = 2;
230+
}
231+
"""
232+
)
233+
schema = resolve_imports(main_proto, [tmpdir])
234+
validator = SchemaValidator(schema)
235+
assert validator.validate()

0 commit comments

Comments
 (0)