Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions compiler/fory_compiler/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ def resolve_imports(

visited.add(file_path)

if file_path.suffix == ".proto":
return _resolve_proto_imports(file_path, import_paths, visited, cache)

# Parse the file
schema = parse_idl_file(file_path)

Expand Down Expand Up @@ -167,6 +170,77 @@ def resolve_imports(
return merged_schema


def _resolve_proto_imports(
file_path: Path,
import_paths: Optional[List[Path]],
visited: Set[Path],
cache: Dict[Path, Schema],
) -> Schema:
"""Proto-specific import resolution."""
from fory_compiler.frontend.proto import ProtoFrontend

frontend = ProtoFrontend()
source = file_path.read_text()
proto_schema = frontend.parse_ast(source, str(file_path))
direct_import_proto_schemas = []
imported_enums = []
imported_messages = []
imported_unions = []
imported_services = []
file_packages: Dict[str, Optional[str]] = {
str(file_path): proto_schema.package
} # file -> the package it belongs.

for imp_path_str in proto_schema.imports:
import_path = resolve_import_path(imp_path_str, file_path, import_paths or [])
if import_path is None:
searched = [str(file_path.parent)]
if import_paths:
searched.extend(str(p) for p in import_paths)
raise ImportError(
f"Import not found: {imp_path_str}\n Searched in: {', '.join(searched)}"
)
imp_source = import_path.read_text()
imp_proto_ast = frontend.parse_ast(imp_source, str(import_path))
direct_import_proto_schemas.append(imp_proto_ast)

# Recursively resolve the imported file
imported_full = resolve_imports(
import_path, import_paths, visited.copy(), cache
)
imported_enums.extend(imported_full.enums)
imported_messages.extend(imported_full.messages)
imported_unions.extend(imported_full.unions)
imported_services.extend(imported_full.services)

# Collect file->package mappings from the imported schema.
if imported_full.file_packages:
file_packages.update(imported_full.file_packages)
else:
file_packages[str(import_path)] = imported_full.package

schema = frontend.parse_with_imports(
source, str(file_path), direct_import_proto_schemas
)

merged_schema = Schema(
package=schema.package,
package_alias=schema.package_alias,
imports=schema.imports,
enums=imported_enums + schema.enums,
messages=imported_messages + schema.messages,
unions=imported_unions + schema.unions,
services=imported_services + schema.services,
options=schema.options,
source_file=schema.source_file,
source_format=schema.source_format,
file_packages=file_packages,
)

cache[file_path] = copy.deepcopy(merged_schema)
return merged_schema


def go_package_info(schema: Schema) -> Tuple[Optional[str], str]:
"""Return (import_path, package_name) for Go."""
go_package = schema.get_option("go_package")
Expand Down
23 changes: 21 additions & 2 deletions compiler/fory_compiler/frontend/proto/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
"""Proto frontend."""

import sys
from typing import List, Optional

from fory_compiler.frontend.base import BaseFrontend, FrontendError
from fory_compiler.frontend.proto.ast import ProtoSchema
from fory_compiler.frontend.proto.lexer import Lexer, LexerError
from fory_compiler.frontend.proto.parser import Parser, ParseError
from fory_compiler.frontend.proto.translator import ProtoTranslator
Expand All @@ -32,15 +34,32 @@ class ProtoFrontend(BaseFrontend):
extensions = [".proto"]

def parse(self, source: str, filename: str = "<input>") -> Schema:
return self.parse_with_imports(source, filename)

def parse_ast(self, source: str, filename: str = "<input>") -> ProtoSchema:
"""Parse proto source into a proto AST without translating to Fory IR."""
try:
lexer = Lexer(source, filename)
tokens = lexer.tokenize()
parser = Parser(tokens, filename)
proto_schema = parser.parse()
return parser.parse()
except (LexerError, ParseError) as exc:
raise FrontendError(exc.message, filename, exc.line, exc.column) from exc

translator = ProtoTranslator(proto_schema)
def parse_with_imports(
self,
source: str,
filename: str = "<input>",
direct_import_proto_schemas: Optional[List[ProtoSchema]] = None,
) -> Schema:
"""Parse proto source and translate to Fory IR.

`direct_import_proto_schemas` supplies the proto ASTs of **directly**
imported files so the translator can resolve cross-file type references
and enforce import-visibility rules.
"""
proto_schema = self.parse_ast(source, filename)
translator = ProtoTranslator(proto_schema, direct_import_proto_schemas)
schema = translator.translate()

for warning in translator.warnings:
Expand Down
Loading
Loading