Skip to content

Commit 9c8a7a6

Browse files
committed
refacor(compiler): respect package and name resolution rules for protobuf
1 parent f5d72ab commit 9c8a7a6

File tree

13 files changed

+797
-100
lines changed

13 files changed

+797
-100
lines changed

compiler/fory_compiler/cli.py

Lines changed: 74 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,9 @@ def resolve_imports(
116116

117117
visited.add(file_path)
118118

119+
if file_path.suffix == ".proto":
120+
return _resolve_proto_imports(file_path, import_paths, visited, cache)
121+
119122
# Parse the file
120123
schema = parse_idl_file(file_path)
121124

@@ -167,6 +170,77 @@ def resolve_imports(
167170
return merged_schema
168171

169172

173+
def _resolve_proto_imports(
174+
file_path: Path,
175+
import_paths: Optional[List[Path]],
176+
visited: Set[Path],
177+
cache: Dict[Path, Schema],
178+
) -> Schema:
179+
"""Proto-specific import resolution."""
180+
from fory_compiler.frontend.proto import ProtoFrontend
181+
182+
frontend = ProtoFrontend()
183+
source = file_path.read_text()
184+
proto_schema = frontend.parse_ast(source, str(file_path))
185+
direct_import_proto_schemas = []
186+
imported_enums = []
187+
imported_messages = []
188+
imported_unions = []
189+
imported_services = []
190+
file_packages: Dict[str, Optional[str]] = {
191+
str(file_path): proto_schema.package
192+
} # file -> the package it belongs.
193+
194+
for imp_path_str in proto_schema.imports:
195+
import_path = resolve_import_path(imp_path_str, file_path, import_paths or [])
196+
if import_path is None:
197+
searched = [str(file_path.parent)]
198+
if import_paths:
199+
searched.extend(str(p) for p in import_paths)
200+
raise ImportError(
201+
f"Import not found: {imp_path_str}\n Searched in: {', '.join(searched)}"
202+
)
203+
imp_source = import_path.read_text()
204+
imp_proto_ast = frontend.parse_ast(imp_source, str(import_path))
205+
direct_import_proto_schemas.append(imp_proto_ast)
206+
207+
# Recursively resolve the imported file
208+
imported_full = resolve_imports(
209+
import_path, import_paths, visited.copy(), cache
210+
)
211+
imported_enums.extend(imported_full.enums)
212+
imported_messages.extend(imported_full.messages)
213+
imported_unions.extend(imported_full.unions)
214+
imported_services.extend(imported_full.services)
215+
216+
# Collect file->package mappings from the imported schema.
217+
if imported_full.file_packages:
218+
file_packages.update(imported_full.file_packages)
219+
else:
220+
file_packages[str(import_path)] = imported_full.package
221+
222+
schema = frontend.parse_with_imports(
223+
source, str(file_path), direct_import_proto_schemas
224+
)
225+
226+
merged_schema = Schema(
227+
package=schema.package,
228+
package_alias=schema.package_alias,
229+
imports=schema.imports,
230+
enums=imported_enums + schema.enums,
231+
messages=imported_messages + schema.messages,
232+
unions=imported_unions + schema.unions,
233+
services=imported_services + schema.services,
234+
options=schema.options,
235+
source_file=schema.source_file,
236+
source_format=schema.source_format,
237+
file_packages=file_packages,
238+
)
239+
240+
cache[file_path] = copy.deepcopy(merged_schema)
241+
return merged_schema
242+
243+
170244
def go_package_info(schema: Schema) -> Tuple[Optional[str], str]:
171245
"""Return (import_path, package_name) for Go."""
172246
go_package = schema.get_option("go_package")

compiler/fory_compiler/frontend/proto/__init__.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
"""Proto frontend."""
1919

2020
import sys
21+
from typing import List, Optional
2122

2223
from fory_compiler.frontend.base import BaseFrontend, FrontendError
24+
from fory_compiler.frontend.proto.ast import ProtoSchema
2325
from fory_compiler.frontend.proto.lexer import Lexer, LexerError
2426
from fory_compiler.frontend.proto.parser import Parser, ParseError
2527
from fory_compiler.frontend.proto.translator import ProtoTranslator
@@ -32,15 +34,32 @@ class ProtoFrontend(BaseFrontend):
3234
extensions = [".proto"]
3335

3436
def parse(self, source: str, filename: str = "<input>") -> Schema:
37+
return self.parse_with_imports(source, filename)
38+
39+
def parse_ast(self, source: str, filename: str = "<input>") -> ProtoSchema:
40+
"""Parse proto source into a proto AST without translating to Fory IR."""
3541
try:
3642
lexer = Lexer(source, filename)
3743
tokens = lexer.tokenize()
3844
parser = Parser(tokens, filename)
39-
proto_schema = parser.parse()
45+
return parser.parse()
4046
except (LexerError, ParseError) as exc:
4147
raise FrontendError(exc.message, filename, exc.line, exc.column) from exc
4248

43-
translator = ProtoTranslator(proto_schema)
49+
def parse_with_imports(
50+
self,
51+
source: str,
52+
filename: str = "<input>",
53+
direct_import_proto_schemas: Optional[List[ProtoSchema]] = None,
54+
) -> Schema:
55+
"""Parse proto source and translate to Fory IR.
56+
57+
`direct_import_proto_schemas` supplies the proto ASTs of **directly**
58+
imported files so the translator can resolve cross-file type references
59+
and enforce import-visibility rules.
60+
"""
61+
proto_schema = self.parse_ast(source, filename)
62+
translator = ProtoTranslator(proto_schema, direct_import_proto_schemas)
4463
schema = translator.translate()
4564

4665
for warning in translator.warnings:

0 commit comments

Comments
 (0)