diff --git a/compiler/fory_compiler/cli.py b/compiler/fory_compiler/cli.py
index 96325eaa8f..d236e4ae95 100644
--- a/compiler/fory_compiler/cli.py
+++ b/compiler/fory_compiler/cli.py
@@ -116,6 +116,9 @@ def resolve_imports(
visited.add(file_path)
+ if file_path.suffix == ".proto":
+ return _resolve_proto_imports(file_path, import_paths, visited, cache)
+
# Parse the file
schema = parse_idl_file(file_path)
@@ -167,6 +170,77 @@ def resolve_imports(
return merged_schema
+def _resolve_proto_imports(
+ file_path: Path,
+ import_paths: Optional[List[Path]],
+ visited: Set[Path],
+ cache: Dict[Path, Schema],
+) -> Schema:
+ """Proto-specific import resolution."""
+ from fory_compiler.frontend.proto import ProtoFrontend
+
+ frontend = ProtoFrontend()
+ source = file_path.read_text()
+ proto_schema = frontend.parse_ast(source, str(file_path))
+ direct_import_proto_schemas = []
+ imported_enums = []
+ imported_messages = []
+ imported_unions = []
+ imported_services = []
+ file_packages: Dict[str, Optional[str]] = {
+ str(file_path): proto_schema.package
+    } # file -> the package it belongs to.
+
+ for imp_path_str in proto_schema.imports:
+ import_path = resolve_import_path(imp_path_str, file_path, import_paths or [])
+ if import_path is None:
+ searched = [str(file_path.parent)]
+ if import_paths:
+ searched.extend(str(p) for p in import_paths)
+ raise ImportError(
+ f"Import not found: {imp_path_str}\n Searched in: {', '.join(searched)}"
+ )
+ imp_source = import_path.read_text()
+ imp_proto_ast = frontend.parse_ast(imp_source, str(import_path))
+ direct_import_proto_schemas.append(imp_proto_ast)
+
+ # Recursively resolve the imported file
+ imported_full = resolve_imports(
+ import_path, import_paths, visited.copy(), cache
+ )
+ imported_enums.extend(imported_full.enums)
+ imported_messages.extend(imported_full.messages)
+ imported_unions.extend(imported_full.unions)
+ imported_services.extend(imported_full.services)
+
+ # Collect file->package mappings from the imported schema.
+ if imported_full.file_packages:
+ file_packages.update(imported_full.file_packages)
+ else:
+ file_packages[str(import_path)] = imported_full.package
+
+ schema = frontend.parse_with_imports(
+ source, str(file_path), direct_import_proto_schemas
+ )
+
+ merged_schema = Schema(
+ package=schema.package,
+ package_alias=schema.package_alias,
+ imports=schema.imports,
+ enums=imported_enums + schema.enums,
+ messages=imported_messages + schema.messages,
+ unions=imported_unions + schema.unions,
+ services=imported_services + schema.services,
+ options=schema.options,
+ source_file=schema.source_file,
+ source_format=schema.source_format,
+ file_packages=file_packages,
+ )
+
+ cache[file_path] = copy.deepcopy(merged_schema)
+ return merged_schema
+
+
def go_package_info(schema: Schema) -> Tuple[Optional[str], str]:
"""Return (import_path, package_name) for Go."""
go_package = schema.get_option("go_package")
diff --git a/compiler/fory_compiler/frontend/proto/__init__.py b/compiler/fory_compiler/frontend/proto/__init__.py
index 2d3a30e77f..f72da1bc16 100644
--- a/compiler/fory_compiler/frontend/proto/__init__.py
+++ b/compiler/fory_compiler/frontend/proto/__init__.py
@@ -18,8 +18,10 @@
"""Proto frontend."""
import sys
+from typing import List, Optional
from fory_compiler.frontend.base import BaseFrontend, FrontendError
+from fory_compiler.frontend.proto.ast import ProtoSchema
from fory_compiler.frontend.proto.lexer import Lexer, LexerError
from fory_compiler.frontend.proto.parser import Parser, ParseError
from fory_compiler.frontend.proto.translator import ProtoTranslator
@@ -32,15 +34,32 @@ class ProtoFrontend(BaseFrontend):
extensions = [".proto"]
def parse(self, source: str, filename: str = "") -> Schema:
+ return self.parse_with_imports(source, filename)
+
+ def parse_ast(self, source: str, filename: str = "") -> ProtoSchema:
+ """Parse proto source into a proto AST without translating to Fory IR."""
try:
lexer = Lexer(source, filename)
tokens = lexer.tokenize()
parser = Parser(tokens, filename)
- proto_schema = parser.parse()
+ return parser.parse()
except (LexerError, ParseError) as exc:
raise FrontendError(exc.message, filename, exc.line, exc.column) from exc
- translator = ProtoTranslator(proto_schema)
+ def parse_with_imports(
+ self,
+ source: str,
+ filename: str = "",
+ direct_import_proto_schemas: Optional[List[ProtoSchema]] = None,
+ ) -> Schema:
+ """Parse proto source and translate to Fory IR.
+
+ `direct_import_proto_schemas` supplies the proto ASTs of **directly**
+ imported files so the translator can resolve cross-file type references
+ and enforce import-visibility rules.
+ """
+ proto_schema = self.parse_ast(source, filename)
+ translator = ProtoTranslator(proto_schema, direct_import_proto_schemas)
schema = translator.translate()
for warning in translator.warnings:
diff --git a/compiler/fory_compiler/frontend/proto/translator.py b/compiler/fory_compiler/frontend/proto/translator.py
index 5554bb670b..24b5a37203 100644
--- a/compiler/fory_compiler/frontend/proto/translator.py
+++ b/compiler/fory_compiler/frontend/proto/translator.py
@@ -49,8 +49,24 @@
from fory_compiler.ir.types import PrimitiveKind
+class TranslationError(Exception):
+ """Raised when a type reference cannot be resolved during proto translation."""
+
+ def __init__(self, message: str, line: int = 0, column: int = 0) -> None:
+ super().__init__(message)
+ self.line = line
+ self.column = column
+
+
class ProtoTranslator:
- """Translate Proto AST to Fory IR."""
+ """Translate Proto AST to Fory IR.
+
+ Accepts an optional list of *direct* import proto schemas so that type
+ references are resolved to fully-qualified names and import-visibility is enforced during
+ translation.
+ Types from transitively-imported files (not in `direct_import_proto_schemas`)
+ are absent from the symbol table and will cause a resolution error, matching protoc semantics.
+ """
TYPE_MAPPING: Dict[str, PrimitiveKind] = {
"bool": PrimitiveKind.BOOL,
@@ -86,9 +102,141 @@ class ProtoTranslator:
"tagged_uint64": PrimitiveKind.TAGGED_UINT64,
}
- def __init__(self, proto_schema: ProtoSchema):
+ def __init__(
+ self,
+ proto_schema: ProtoSchema,
+ direct_import_proto_schemas: Optional[List[ProtoSchema]] = None,
+ ):
self.proto_schema = proto_schema
+ self.direct_import_proto_schemas: List[ProtoSchema] = (
+ direct_import_proto_schemas or []
+ )
self.warnings: List[str] = []
+ # symbol table: fully-qualified name -> (source_file, package).
+        # Only the own file and directly-imported files are included in the symbol table; transitively-imported files are excluded.
+ self._symbol_table: Dict[str, Tuple[str, Optional[str]]] = (
+ self._build_symbol_table()
+ )
+
+ def _build_symbol_table(self) -> Dict[str, Tuple[str, Optional[str]]]:
+ table: Dict[str, Tuple[str, Optional[str]]] = {}
+ own_file = self.proto_schema.source_file or ""
+ own_pkg = self.proto_schema.package
+ self._collect_proto_message_qualified_names(
+ self.proto_schema.messages, own_pkg, "", own_file, table
+ )
+ self._collect_proto_enum_qualified_names(
+ self.proto_schema.enums, own_pkg, "", own_file, table
+ )
+ for imp_ps in self.direct_import_proto_schemas:
+ imp_file = imp_ps.source_file or ""
+ imp_pkg = imp_ps.package
+ self._collect_proto_message_qualified_names(
+ imp_ps.messages, imp_pkg, "", imp_file, table
+ )
+ self._collect_proto_enum_qualified_names(
+ imp_ps.enums, imp_pkg, "", imp_file, table
+ )
+ return table
+
+ def _collect_proto_message_qualified_names(
+ self,
+ messages: List[ProtoMessage],
+ package: Optional[str],
+ parent_path: str,
+ source_file: str,
+ table: Dict[str, Tuple[str, Optional[str]]],
+ ) -> None:
+ for msg in messages:
+ path = f"{parent_path}.{msg.name}" if parent_path else msg.name
+ qualified_name = f"{package}.{path}" if package else path
+ table[qualified_name] = (source_file, package)
+ # Handle nested messages.
+ self._collect_proto_message_qualified_names(
+ msg.nested_messages, package, path, source_file, table
+ )
+ # Handle nested enums.
+ self._collect_proto_enum_qualified_names(
+ msg.nested_enums, package, path, source_file, table
+ )
+
+ def _collect_proto_enum_qualified_names(
+ self,
+ enums: List[ProtoEnum],
+ package: Optional[str],
+ parent_path: str,
+ source_file: str,
+ table: Dict[str, Tuple[str, Optional[str]]],
+ ) -> None:
+ for enum in enums:
+ path = f"{parent_path}.{enum.name}" if parent_path else enum.name
+ qualified_name = f"{package}.{path}" if package else path
+ table[qualified_name] = (source_file, package)
+
+ def _resolve_ref(
+ self,
+ raw_name: str,
+ enclosing_path: List[str],
+ line: int,
+ column: int,
+ ) -> str:
+ """Resolve a proto type-reference string to its fully-qualified name.
+
+ Raise `TranslationError` if the name cannot be resolved or is not
+        visible (i.e. not in the own file or a directly-imported file).
+ """
+ cleaned = raw_name.lstrip(".")
+ is_absolute = raw_name.startswith(".")
+
+ if is_absolute:
+ # Absolute names (with leading dot) are looked up directly,
+ # e.g.: ".com.example.Foo" -> "com.example.Foo".
+ if cleaned in self._symbol_table:
+ return cleaned
+ raise TranslationError(f"Unknown type '{raw_name}'", line, column)
+
+ parts = cleaned.split(".")
+ own_pkg = self.proto_schema.package
+
+ # Build scope-prefix list from innermost to outermost scope.
+ # Ref: https://protobuf.dev/programming-guides/proto3/#name-resolution
+ # e.g., for a reference inside package "com.example", message "Outer", nested
+ # message "Inner" (enclosing_path = ["Outer", "Inner"]) the prefixes are:
+ # ["com.example.Outer.Inner", "com.example.Outer", "com.example"].
+ scope_prefixes: List[Optional[str]] = []
+ for depth in range(len(enclosing_path), -1, -1):
+ scope_parts = enclosing_path[:depth]
+ if scope_parts:
+ inner = ".".join(scope_parts)
+ scope_prefixes.append(f"{own_pkg}.{inner}" if own_pkg else inner)
+ else:
+ scope_prefixes.append(own_pkg)
+
+ for prefix in scope_prefixes:
+ first_qualified_name = f"{prefix}.{parts[0]}" if prefix else parts[0]
+ if first_qualified_name not in self._symbol_table:
+ continue
+
+ if len(parts) == 1:
+ return first_qualified_name
+
+ current = first_qualified_name
+ for part in parts[1:]:
+ nxt = f"{current}.{part}"
+ if nxt not in self._symbol_table:
+ raise TranslationError(
+ f"Nested type '{part}' not found in '{current}'; "
+ f"cannot resolve '{raw_name}'",
+ line,
+ column,
+ )
+ current = nxt
+ return current
+
+ if cleaned in self._symbol_table:
+ return cleaned
+
+ raise TranslationError(f"Unknown type '{raw_name}'", line, column)
def _location(self, line: int, column: int) -> SourceLocation:
return SourceLocation(
@@ -99,16 +247,27 @@ def _location(self, line: int, column: int) -> SourceLocation:
)
def translate(self) -> Schema:
+ # Collect the file_packages mapping so the merged schema can do fully-qualified name lookup later
+ file_packages: Dict[str, Optional[str]] = {
+ self.proto_schema.source_file or "": self.proto_schema.package
+ }
+ for imp_ps in self.direct_import_proto_schemas:
+ imp_file = imp_ps.source_file or ""
+ file_packages[imp_file] = imp_ps.package
+
return Schema(
package=self.proto_schema.package,
package_alias=None,
imports=self._translate_imports(),
enums=[self._translate_enum(e) for e in self.proto_schema.enums],
- messages=[self._translate_message(m) for m in self.proto_schema.messages],
+ messages=[
+ self._translate_message(m, []) for m in self.proto_schema.messages
+ ],
services=[self._translate_service(s) for s in self.proto_schema.services],
options=self._translate_file_options(self.proto_schema.options),
source_file=self.proto_schema.source_file,
source_format="proto",
+ file_packages=file_packages,
)
def _translate_imports(self) -> List[Import]:
@@ -145,21 +304,24 @@ def _translate_enum(self, proto_enum: ProtoEnum) -> Enum:
location=self._location(proto_enum.line, proto_enum.column),
)
- def _translate_message(self, proto_msg: ProtoMessage) -> Message:
+ def _translate_message(
+ self, proto_msg: ProtoMessage, enclosing_path: List[str]
+ ) -> Message:
type_id, options = self._translate_type_options(proto_msg.options)
- fields = [self._translate_field(f) for f in proto_msg.fields]
+ msg_path = enclosing_path + [proto_msg.name]
+ fields = [self._translate_field(f, msg_path) for f in proto_msg.fields]
nested_unions = []
for oneof in proto_msg.oneofs:
oneof_type_name = self._oneof_type_name(oneof.name)
nested_unions.append(
- self._translate_oneof(oneof, oneof_type_name, proto_msg)
+ self._translate_oneof(oneof, oneof_type_name, proto_msg, msg_path)
)
if not oneof.fields:
continue
union_field = self._translate_oneof_field_reference(oneof, oneof_type_name)
fields.append(union_field)
nested_messages = [
- self._translate_message(m) for m in proto_msg.nested_messages
+ self._translate_message(m, msg_path) for m in proto_msg.nested_messages
]
nested_enums = [self._translate_enum(e) for e in proto_msg.nested_enums]
return Message(
@@ -175,8 +337,10 @@ def _translate_message(self, proto_msg: ProtoMessage) -> Message:
location=self._location(proto_msg.line, proto_msg.column),
)
- def _translate_field(self, proto_field: ProtoField) -> Field:
- field_type = self._translate_field_type(proto_field.field_type)
+ def _translate_field(
+ self, proto_field: ProtoField, enclosing_path: List[str]
+ ) -> Field:
+ field_type = self._translate_field_type(proto_field.field_type, enclosing_path)
ref, nullable, options, type_override = self._translate_field_options(
proto_field.options
)
@@ -251,8 +415,9 @@ def _translate_oneof(
oneof: ProtoOneof,
oneof_type_name: str,
_parent: ProtoMessage,
+ enclosing_path: List[str],
) -> Union:
- fields = [self._translate_oneof_case(f) for f in oneof.fields]
+ fields = [self._translate_oneof_case(f, enclosing_path) for f in oneof.fields]
return Union(
name=oneof_type_name,
type_id=None,
@@ -263,8 +428,10 @@ def _translate_oneof(
location=self._location(oneof.line, oneof.column),
)
- def _translate_oneof_case(self, proto_field: ProtoField) -> Field:
- field_type = self._translate_field_type(proto_field.field_type)
+ def _translate_oneof_case(
+ self, proto_field: ProtoField, enclosing_path: List[str]
+ ) -> Field:
+ field_type = self._translate_field_type(proto_field.field_type, enclosing_path)
ref, _nullable, options, type_override = self._translate_field_options(
proto_field.options
)
@@ -303,20 +470,35 @@ def _translate_oneof_field_reference(
location=self._location(oneof.line, oneof.column),
)
- def _translate_field_type(self, proto_type: ProtoType):
+ def _translate_field_type(
+ self, proto_type: ProtoType, enclosing_path: List[str]
+ ) -> FieldType:
if proto_type.is_map:
- key_type = self._translate_type_name(proto_type.map_key_type or "")
- value_type = self._translate_type_name(proto_type.map_value_type or "")
+ key_type = self._translate_type_name(
+ proto_type.map_key_type or "", [], proto_type.line, proto_type.column
+ )
+ value_type = self._translate_type_name(
+ proto_type.map_value_type or "",
+ enclosing_path,
+ proto_type.line,
+ proto_type.column,
+ )
return MapType(
key_type,
value_type,
location=self._location(proto_type.line, proto_type.column),
)
return self._translate_type_name(
- proto_type.name, proto_type.line, proto_type.column
+ proto_type.name, enclosing_path, proto_type.line, proto_type.column
)
- def _translate_type_name(self, type_name: str, line: int = 0, column: int = 0):
+ def _translate_type_name(
+ self,
+ type_name: str,
+ enclosing_path: List[str],
+ line: int = 0,
+ column: int = 0,
+ ) -> FieldType:
cleaned = type_name.lstrip(".")
if cleaned in self.WELL_KNOWN_TYPES:
return PrimitiveType(
@@ -328,7 +510,39 @@ def _translate_type_name(self, type_name: str, line: int = 0, column: int = 0):
self.TYPE_MAPPING[cleaned],
location=self._location(line, column),
)
- return NamedType(cleaned, location=self._location(line, column))
+ # Resolve user-defined type reference to its fully-qualified name.
+ try:
+ qualified_name = self._resolve_ref(type_name, enclosing_path, line, column)
+ except TranslationError as exc:
+ from fory_compiler.frontend.base import FrontendError
+
+ raise FrontendError(
+ str(exc),
+ self.proto_schema.source_file or "",
+ exc.line,
+ exc.column,
+ ) from exc
+ # Compute display_name: the name that code generators should use as output type string.
+ cleaned = type_name.lstrip(".")
+ _, type_pkg = self._symbol_table.get(qualified_name, (None, None))
+ own_pkg = self.proto_schema.package
+ if type_pkg == own_pkg:
+ # Same package: use the written reference, minus any redundant package prefix.
+ if own_pkg and cleaned.startswith(own_pkg + "."):
+ display_name: Optional[str] = cleaned[len(own_pkg) + 1 :]
+ else:
+ display_name = cleaned
+ else:
+ # Cross-package: strip the type's package prefix so generators get the type-local path.
+ if type_pkg and qualified_name.startswith(type_pkg + "."):
+ display_name = qualified_name[len(type_pkg) + 1 :]
+ else:
+ display_name = qualified_name
+ return NamedType(
+ qualified_name,
+ location=self._location(line, column),
+ display_name=display_name,
+ )
def _translate_type_options(
self, options: Dict[str, object]
@@ -403,16 +617,16 @@ def _translate_service(self, proto_service: ProtoService) -> Service:
def _translate_rpc_method(self, proto_method: ProtoRpcMethod) -> RpcMethod:
# Translate ProtoRpcMethod to RpcMethod
_, options = self._translate_type_options(proto_method.options)
+ req_type = self._translate_type_name(
+ proto_method.request_type, [], proto_method.line, proto_method.column
+ )
+ resp_type = self._translate_type_name(
+ proto_method.response_type, [], proto_method.line, proto_method.column
+ )
return RpcMethod(
name=proto_method.name,
- request_type=NamedType(
- name=proto_method.request_type,
- location=self._location(proto_method.line, proto_method.column),
- ),
- response_type=NamedType(
- name=proto_method.response_type,
- location=self._location(proto_method.line, proto_method.column),
- ),
+ request_type=req_type,
+ response_type=resp_type,
client_streaming=proto_method.client_streaming,
server_streaming=proto_method.server_streaming,
options=options,
diff --git a/compiler/fory_compiler/generators/base.py b/compiler/fory_compiler/generators/base.py
index 379f9c1b54..ca9cadea45 100644
--- a/compiler/fory_compiler/generators/base.py
+++ b/compiler/fory_compiler/generators/base.py
@@ -229,7 +229,7 @@ def format_idl_type(self, field_type: FieldType) -> str:
if isinstance(field_type, PrimitiveType):
return field_type.kind.value
if isinstance(field_type, NamedType):
- return field_type.name
+ return field_type.display_name or field_type.name
if isinstance(field_type, ListType):
element = self.format_idl_type(field_type.element_type)
return f"list<{element}>"
diff --git a/compiler/fory_compiler/generators/cpp.py b/compiler/fory_compiler/generators/cpp.py
index a5bc8b418c..e842fed24d 100644
--- a/compiler/fory_compiler/generators/cpp.py
+++ b/compiler/fory_compiler/generators/cpp.py
@@ -503,7 +503,7 @@ def collect_type_dependencies(
if isinstance(field_type, PrimitiveType):
return
if isinstance(field_type, NamedType):
- type_name = field_type.name
+ type_name = field_type.display_name or field_type.name
if self.is_nested_type_reference(type_name, parent_stack):
return
top_level = type_name.split(".")[0]
@@ -598,7 +598,9 @@ def is_message_type(
) -> bool:
if not isinstance(field_type, NamedType):
return False
- resolved = self.resolve_named_type(field_type.name, parent_stack)
+ resolved = self.resolve_named_type(
+ field_type.display_name or field_type.name, parent_stack
+ )
return isinstance(resolved, Message)
def is_weak_ref(self, options: dict) -> bool:
@@ -620,7 +622,9 @@ def is_union_type(
) -> bool:
if not isinstance(field_type, NamedType):
return False
- resolved = self.resolve_named_type(field_type.name, parent_stack)
+ resolved = self.resolve_named_type(
+ field_type.display_name or field_type.name, parent_stack
+ )
return isinstance(resolved, Union)
def is_enum_type(
@@ -628,7 +632,9 @@ def is_enum_type(
) -> bool:
if not isinstance(field_type, NamedType):
return False
- resolved = self.resolve_named_type(field_type.name, parent_stack)
+ resolved = self.resolve_named_type(
+ field_type.display_name or field_type.name, parent_stack
+ )
return isinstance(resolved, Enum)
def get_field_member_name(self, field: Field) -> str:
@@ -1474,7 +1480,8 @@ def generate_namespaced_type(
return base_type
if isinstance(field_type, NamedType):
- type_name = self.resolve_nested_type_name(field_type.name, parent_stack)
+ local_name = field_type.display_name or field_type.name
+ type_name = self.resolve_nested_type_name(local_name, parent_stack)
named_type = self.schema.get_type(field_type.name)
if named_type is not None and self.is_imported_type(named_type):
namespace = self._namespace_for_type(named_type)
@@ -1652,7 +1659,8 @@ def generate_type(
return base_type
elif isinstance(field_type, NamedType):
- type_name = self.resolve_nested_type_name(field_type.name, parent_stack)
+ local_name = field_type.display_name or field_type.name
+ type_name = self.resolve_nested_type_name(local_name, parent_stack)
named_type = self.schema.get_type(field_type.name)
if named_type is not None and self.is_imported_type(named_type):
ns = self._namespace_for_type(named_type)
diff --git a/compiler/fory_compiler/generators/go.py b/compiler/fory_compiler/generators/go.py
index 0fd9d4b0c1..6b7d1be004 100644
--- a/compiler/fory_compiler/generators/go.py
+++ b/compiler/fory_compiler/generators/go.py
@@ -674,7 +674,9 @@ def get_union_case_type_id_expr(
if isinstance(field.field_type, MapType):
return "fory.MAP"
if isinstance(field.field_type, NamedType):
- type_def = self.resolve_named_type(field.field_type.name, parent_stack)
+ type_def = self.resolve_named_type(
+ field.field_type.display_name or field.field_type.name, parent_stack
+ )
if isinstance(type_def, Enum):
if type_def.type_id is None:
return "fory.NAMED_ENUM"
@@ -970,7 +972,8 @@ def generate_type(
return base_type
elif isinstance(field_type, NamedType):
- type_name = self.resolve_nested_type_name(field_type.name, parent_stack)
+ local_name = field_type.display_name or field_type.name
+ type_name = self.resolve_nested_type_name(local_name, parent_stack)
named_type = self.schema.get_type(field_type.name)
if named_type is not None and self.is_imported_type(named_type):
info = self._import_info_for_type(named_type)
@@ -978,7 +981,7 @@ def generate_type(
getattr(getattr(named_type, "location", None), "file", None)
)
if schema is not None:
- type_name = self._format_imported_type_name(field_type.name, schema)
+ type_name = self._format_imported_type_name(local_name, schema)
if info is not None:
alias, _, _ = info
type_name = f"{alias}.{type_name}"
@@ -1032,9 +1035,8 @@ def get_union_case_type(
) -> str:
"""Return the Go type for a union case."""
if isinstance(field.field_type, NamedType):
- type_name = self.resolve_nested_type_name(
- field.field_type.name, parent_stack
- )
+ local_name = field.field_type.display_name or field.field_type.name
+ type_name = self.resolve_nested_type_name(local_name, parent_stack)
named_type = self.schema.get_type(field.field_type.name)
if named_type is not None and self.is_imported_type(named_type):
info = self._import_info_for_type(named_type)
@@ -1042,9 +1044,7 @@ def get_union_case_type(
getattr(getattr(named_type, "location", None), "file", None)
)
if schema is not None:
- type_name = self._format_imported_type_name(
- field.field_type.name, schema
- )
+ type_name = self._format_imported_type_name(local_name, schema)
if info is not None:
alias, _, _ = info
type_name = f"{alias}.{type_name}"
diff --git a/compiler/fory_compiler/generators/java.py b/compiler/fory_compiler/generators/java.py
index e30f4f26d3..854a3cf29e 100644
--- a/compiler/fory_compiler/generators/java.py
+++ b/compiler/fory_compiler/generators/java.py
@@ -896,7 +896,9 @@ def get_union_case_type_id_expr(
if isinstance(field.field_type, MapType):
return "Types.MAP"
if isinstance(field.field_type, NamedType):
- type_def = self.resolve_named_type(field.field_type.name, parent_stack)
+ type_def = self.resolve_named_type(
+ field.field_type.display_name or field.field_type.name, parent_stack
+ )
if isinstance(type_def, Enum):
if type_def.type_id is None:
return "Types.NAMED_ENUM"
@@ -1135,11 +1137,12 @@ def generate_type(
elif isinstance(field_type, NamedType):
named_type = self.schema.get_type(field_type.name)
+ local_name = field_type.display_name or field_type.name
if named_type is not None and self.is_imported_type(named_type):
java_package = self._java_package_for_type(named_type)
if java_package:
- return f"{java_package}.{field_type.name}"
- return field_type.name
+ return f"{java_package}.{local_name}"
+ return local_name
elif isinstance(field_type, ListType):
# Use specialized primitive lists when available, otherwise primitive arrays.
diff --git a/compiler/fory_compiler/generators/python.py b/compiler/fory_compiler/generators/python.py
index b4c049ff82..5fbe82c540 100644
--- a/compiler/fory_compiler/generators/python.py
+++ b/compiler/fory_compiler/generators/python.py
@@ -727,7 +727,8 @@ def generate_type(
return base_type
elif isinstance(field_type, NamedType):
- type_name = self.resolve_nested_type_name(field_type.name, parent_stack)
+ local_name = field_type.display_name or field_type.name
+ type_name = self.resolve_nested_type_name(local_name, parent_stack)
named_type = self.schema.get_type(field_type.name)
if named_type is not None and self.is_imported_type(named_type):
module = self._module_name_for_type(named_type)
@@ -906,7 +907,7 @@ def get_union_case_runtime_check(
return f"isinstance({value_expr}, datetime.datetime)"
if isinstance(field.field_type, NamedType):
type_name = self.resolve_nested_type_name(
- field.field_type.name, parent_stack
+ field.field_type.display_name or field.field_type.name, parent_stack
)
return f"isinstance({value_expr}, {type_name})"
return None
diff --git a/compiler/fory_compiler/generators/rust.py b/compiler/fory_compiler/generators/rust.py
index 48fad1f9b9..069f3e128f 100644
--- a/compiler/fory_compiler/generators/rust.py
+++ b/compiler/fory_compiler/generators/rust.py
@@ -703,12 +703,13 @@ def generate_type(
return base_type
elif isinstance(field_type, NamedType):
- type_name = self.resolve_nested_type_name(field_type.name, parent_stack)
+ local_name = field_type.display_name or field_type.name
+ type_name = self.resolve_nested_type_name(local_name, parent_stack)
named_type = self.schema.get_type(field_type.name)
if named_type is not None and self.is_imported_type(named_type):
module = self._module_name_for_type(named_type)
if module:
- type_name = self._format_imported_type_name(field_type.name, module)
+ type_name = self._format_imported_type_name(local_name, module)
if ref:
type_name = f"{pointer_type}<{type_name}>"
if nullable:
diff --git a/compiler/fory_compiler/ir/ast.py b/compiler/fory_compiler/ir/ast.py
index 0c1d750099..62e6bcc96a 100644
--- a/compiler/fory_compiler/ir/ast.py
+++ b/compiler/fory_compiler/ir/ast.py
@@ -18,7 +18,7 @@
"""AST node definitions for FDL."""
from dataclasses import dataclass, field
-from typing import List, Optional, Union as TypingUnion
+from typing import Dict, List, Optional, Union as TypingUnion
from fory_compiler.ir.types import PrimitiveKind
@@ -46,10 +46,16 @@ def __repr__(self) -> str:
@dataclass
class NamedType:
- """A reference to a user-defined type (message or enum)."""
+ """A reference to a user-defined type (message or enum).
+
+ `name` is always the lookup key (fully-qualified name for proto schemas).
+    `display_name`, when set, is the package-relative name that code generators should use as the output type string.
+ Generators fall back to `name` when `display_name` is None (FDL / FBS frontends).
+ """
name: str
location: Optional[SourceLocation] = None
+ display_name: Optional[str] = None
def __repr__(self) -> str:
return f"NamedType({self.name})"
@@ -302,6 +308,8 @@ class Schema:
) # File-level options (java_package, go_package, etc.)
source_file: Optional[str] = None
source_format: Optional[str] = None
+ # Maps absolute file path -> package name.
+ file_packages: Optional[Dict[str, Optional[str]]] = None
def __repr__(self) -> str:
opts = f", options={len(self.options)}" if self.options else ""
@@ -317,26 +325,85 @@ def get_option(self, name: str, default: Optional[str] = None) -> Optional[str]:
return self.options.get(name, default)
def get_type(self, name: str) -> Optional[TypingUnion[Message, Enum, "Union"]]:
- """Look up a type by name, supporting qualified names like Parent.Child."""
- # Handle qualified names (e.g., SearchResponse.Result)
- if "." in name:
- parts = name.split(".")
- # Find the top-level type
- current = self._get_top_level_type(parts[0])
- if current is None:
- return None
- # Navigate through nested types
+ """Look up a type by name, supporting qualified names like Parent.Child.
+
+ For proto schemas the translator emits fully-qualified names (e.g. `pkg.Outer.Inner`).
+ When `file_packages` is set we build a fully-qualified name index on first access so
+ those names resolve correctly across a merged multi-package schema.
+ """
+ if "." not in name:
+ return self._get_top_level_type(name)
+
+ # Try stripping the own package prefix first.
+ if self.package and name.startswith(self.package + "."):
+ rest = name[len(self.package) + 1 :]
+ result = self.get_type(rest)
+ if result is not None:
+ return result
+
+ # Try dot-separated nested navigation (handles Outer.Inner, A.B.C).
+ parts = name.split(".")
+ current = self._get_top_level_type(parts[0])
+ if current is not None:
for part in parts[1:]:
if isinstance(current, Message):
current = current.get_nested_type(part)
if current is None:
- return None
+ break
else:
- # Enums don't have nested types
- return None
- return current
- else:
- return self._get_top_level_type(name)
+ current = None
+ break
+ if current is not None:
+ return current
+
+ # For merged proto schemas: look up by full qualified name across all packages.
+ if self.file_packages is not None:
+ return self._get_qualified_name_index().get(name)
+
+ return None
+
+ def _get_qualified_name_index(
+ self,
+ ) -> Dict[str, TypingUnion[Message, Enum, "Union"]]:
+ """Build (and cache) a fully-qualified name -> type index for merged proto schemas."""
+ cached = getattr(self, "_qualified_name_index_cache", None)
+ if cached is not None:
+ return cached
+
+ index: Dict[str, TypingUnion[Message, Enum, "Union"]] = {}
+
+ def pkg_for(location: Optional[SourceLocation]) -> Optional[str]:
+ if not location or not self.file_packages:
+ return self.package
+ return self.file_packages.get(location.file, self.package)
+
+ def add(
+ type_def: TypingUnion[Message, Enum, "Union"],
+ path: str,
+ package: Optional[str],
+ ) -> None:
+ qualified_name = f"{package}.{path}" if package else path
+ index[qualified_name] = type_def
+
+ def walk(msg: Message, package: Optional[str], parent: str) -> None:
+ path = f"{parent}.{msg.name}" if parent else msg.name
+ add(msg, path, package)
+ for e in msg.nested_enums:
+ add(e, f"{path}.{e.name}", package)
+ for u in msg.nested_unions:
+ add(u, f"{path}.{u.name}", package)
+ for m in msg.nested_messages:
+ walk(m, package, path)
+
+ for e in self.enums:
+ add(e, e.name, pkg_for(e.location))
+ for u in self.unions:
+ add(u, u.name, pkg_for(u.location))
+ for m in self.messages:
+ walk(m, pkg_for(m.location), "")
+
+ self._qualified_name_index_cache = index
+ return index
def _get_top_level_type(
self, name: str
diff --git a/compiler/fory_compiler/ir/validator.py b/compiler/fory_compiler/ir/validator.py
index 294657b2f9..3bad7d0cc6 100644
--- a/compiler/fory_compiler/ir/validator.py
+++ b/compiler/fory_compiler/ir/validator.py
@@ -102,9 +102,18 @@ def resolve_hash_source(full_name: str, alias: Optional[str]) -> str:
return f"{package}.{alias}"
return alias
+ own_file = self.schema.source_file or ""
+
def assign_id(type_def, full_name: str) -> None:
if type_def.type_id is not None:
return
+ # For multi-file proto schemas, only assign auto-IDs for the types
+ # defined in the current file; imported types will get their own
+ # IDs when their source file is compiled.
+ if self.schema.file_packages is not None:
+ type_file = type_def.location.file if type_def.location else None
+ if type_file != own_file:
+ return
alias = type_def.options.get("alias")
source_name = resolve_hash_source(full_name, alias)
generated_id = compute_registered_type_id(source_name)
@@ -143,29 +152,32 @@ def walk_message(message: Message, parent_path: str = "") -> None:
for message in self.schema.messages:
walk_message(message)
+ def _type_qualified_name_key(self, type_def) -> str:
+ """Return a deduplication key for a type definition.
+
+ For merged proto schemas, each type is identified by its fully-qualified name
+ so that types with the same simple name from different packages are not treated
+ as duplicates.
+ """
+ if self.schema.file_packages is not None and type_def.location:
+ pkg = self.schema.file_packages.get(type_def.location.file)
+ return f"{pkg}.{type_def.name}" if pkg else type_def.name
+ return type_def.name
+
def _check_duplicate_type_names(self) -> None:
names = {}
- for enum in self.schema.enums:
- if enum.name in names:
- self._error(
- f"Duplicate type name: {enum.name}",
- enum.location or names[enum.name],
- )
- names.setdefault(enum.name, enum.location)
- for union in self.schema.unions:
- if union.name in names:
- self._error(
- f"Duplicate type name: {union.name}",
- union.location or names[union.name],
- )
- names.setdefault(union.name, union.location)
- for message in self.schema.messages:
- if message.name in names:
+ for type_def in (
+ list(self.schema.enums)
+ + list(self.schema.unions)
+ + list(self.schema.messages)
+ ):
+ key = self._type_qualified_name_key(type_def)
+ if key in names:
self._error(
- f"Duplicate type name: {message.name}",
- message.location or names[message.name],
+ f"Duplicate type name: {type_def.name}",
+ type_def.location or names[key],
)
- names.setdefault(message.name, message.location)
+ names.setdefault(key, type_def.location)
def _check_duplicate_type_ids(self) -> None:
type_ids = {}
@@ -312,6 +324,10 @@ def _resolve_named_type(
parts = name.split(".")
if len(parts) > 1:
current = self._find_top_level_type(parts[0])
+ if current is None:
+ # Might be a fully-qualified name emitted by the proto translator; delegate to
+ # Schema.get_type(), which handles package-prefix resolution.
+ return self.schema.get_type(name)
for part in parts[1:]:
if isinstance(current, Message):
current = current.get_nested_type(part)
diff --git a/compiler/fory_compiler/tests/test_proto_frontend.py b/compiler/fory_compiler/tests/test_proto_frontend.py
index 8d87107a04..a632b77fb8 100644
--- a/compiler/fory_compiler/tests/test_proto_frontend.py
+++ b/compiler/fory_compiler/tests/test_proto_frontend.py
@@ -17,9 +17,16 @@
"""Tests for the proto frontend translation."""
+import pytest
+import tempfile
+from pathlib import Path
+
+from fory_compiler.frontend.base import FrontendError
from fory_compiler.frontend.proto import ProtoFrontend
from fory_compiler.ir.ast import PrimitiveType
from fory_compiler.ir.types import PrimitiveKind
+from fory_compiler.cli import resolve_imports
+from fory_compiler.ir.validator import SchemaValidator
def test_proto_type_mapping():
@@ -106,3 +113,292 @@ def test_proto_file_option_enable_auto_type_id():
"""
schema = ProtoFrontend().parse(source)
assert schema.get_option("enable_auto_type_id") is False
+
+
+def test_proto_nested_qualified_types_pass():
+ source = """
+ syntax = "proto3";
+ package com.example;
+
+ message A {
+ message B {
+ message C {}
+ }
+ }
+ message Outer {
+ A.B.C c1 = 1;
+ com.example.A.B.C c2 = 2;
+ .com.example.A.B.C c3 = 3;
+ }
+ """
+ schema = ProtoFrontend().parse(source)
+ validator = SchemaValidator(schema)
+ assert validator.validate()
+
+
+def test_proto_nested_qualified_types_fail():
+ # X is only accessible as A.X; an unqualified X is not in scope at B's level.
+ source = """
+ syntax = "proto3";
+ package demo;
+
+ message A {
+ message X{}
+ }
+ message B {
+ X x1 = 1;
+ X x2 = 2;
+ }
+ """
+ with pytest.raises(FrontendError):
+ ProtoFrontend().parse(source)
+
+
+def test_proto_same_package_qualified_types_pass():
+ source = """
+ syntax = "proto3";
+ package com.example;
+
+ message Foo {}
+
+ message Bar {
+ Foo f1 = 1;
+ com.example.Foo f2 = 2;
+ .com.example.Foo f3 = 3;
+ }
+ """
+ schema = ProtoFrontend().parse(source)
+ validator = SchemaValidator(schema)
+ assert validator.validate()
+
+
+def test_proto_imported_package_qualified_types_fail():
+ # Unqualified 'Address' is not visible from package 'main'; must use 'common.Address'.
+ with tempfile.TemporaryDirectory() as tmpdir:
+ tmpdir = Path(tmpdir)
+ common_proto = tmpdir / "common.proto"
+ common_proto.write_text(
+ """
+ syntax = "proto3";
+ package common;
+
+ message Address {}
+ """
+ )
+ main_proto = tmpdir / "main.proto"
+ main_proto.write_text(
+ """
+ syntax = "proto3";
+ package main;
+ import "common.proto";
+
+ message User {
+ Address addr1 = 1;
+ Address addr2 = 2;
+ }
+ """
+ )
+ with pytest.raises(FrontendError):
+ resolve_imports(main_proto, [tmpdir])
+
+
+def test_proto_imported_package_qualified_types_pass():
+ with tempfile.TemporaryDirectory() as tmpdir:
+ tmpdir = Path(tmpdir)
+ common1_proto = tmpdir / "common1.proto"
+ common1_proto.write_text(
+ """
+ syntax = "proto3";
+ package com.lib1;
+
+ message Address {
+ message Country {}
+ }
+ """
+ )
+ common2_proto = tmpdir / "common2.proto"
+ common2_proto.write_text(
+ """
+ syntax = "proto3";
+ package com.lib2;
+
+ message Address {}
+ """
+ )
+ main_proto = tmpdir / "main.proto"
+ main_proto.write_text(
+ """
+ syntax = "proto3";
+ package main;
+ import "common1.proto";
+ import "common2.proto";
+
+ message User {
+ com.lib1.Address a1 = 1;
+ .com.lib2.Address a2 = 2;
+ com.lib1.Address.Country a3 = 3;
+ }
+ """
+ )
+ schema = resolve_imports(main_proto, [tmpdir])
+ validator = SchemaValidator(schema)
+ assert validator.validate()
+
+
+def test_proto_transitive_imports_fail():
+ # baz.proto is only transitively imported via bar.proto; main.proto must not reference baz.Foo directly.
+ with tempfile.TemporaryDirectory() as tmpdir:
+ tmpdir = Path(tmpdir)
+ (tmpdir / "baz.proto").write_text(
+ """
+ syntax = "proto3";
+ package baz;
+
+ message Foo {}
+ """
+ )
+ (tmpdir / "bar.proto").write_text(
+ """
+ syntax = "proto3";
+ package bar;
+ import "baz.proto";
+
+ message Bar { baz.Foo foo = 1; }
+ """
+ )
+ main_proto = tmpdir / "main.proto"
+ main_proto.write_text(
+ """
+ syntax = "proto3";
+ package main;
+ import "bar.proto";
+
+ message User { baz.Foo foo = 1; }
+ """
+ )
+ with pytest.raises(FrontendError):
+ resolve_imports(main_proto, [tmpdir])
+
+
+def test_proto_same_package_transitive_import_fail():
+ # c1.proto and c2.proto share package common; main.proto only imports
+ # c2.proto. Referencing common.Foo (defined in c1.proto) must fail because
+ # c1.proto is not directly imported, even though its package name matches
+ # the directly-imported c2.proto.
+ with tempfile.TemporaryDirectory() as tmpdir:
+ tmpdir = Path(tmpdir)
+ (tmpdir / "c1.proto").write_text(
+ """
+ syntax = "proto3";
+ package common;
+
+ message Foo {}
+ """
+ )
+ (tmpdir / "c2.proto").write_text(
+ """
+ syntax = "proto3";
+ package common;
+ import "c1.proto";
+
+ message Bar {}
+ """
+ )
+ main_proto = tmpdir / "main.proto"
+ main_proto.write_text(
+ """
+ syntax = "proto3";
+ package main;
+ import "c2.proto";
+
+ message User { common.Foo foo = 1; }
+ """
+ )
+ with pytest.raises(FrontendError):
+ resolve_imports(main_proto, [tmpdir])
+
+
+def test_proto_local_type_shadows_import_pass():
+ # main.proto defines its own Address and also imports common.proto which
+ # defines another Address. An unqualified field reference Address should
+ # resolve to the local definition and pass validation.
+ with tempfile.TemporaryDirectory() as tmpdir:
+ tmpdir = Path(tmpdir)
+ (tmpdir / "common.proto").write_text(
+ """
+ syntax = "proto3";
+ package common;
+
+ message Address {}
+ """
+ )
+ main_proto = tmpdir / "main.proto"
+ main_proto.write_text(
+ """
+ syntax = "proto3";
+ package main;
+ import "common.proto";
+
+ message Address {}
+ message User { Address addr = 1; }
+ """
+ )
+ schema = resolve_imports(main_proto, [tmpdir])
+ validator = SchemaValidator(schema)
+ assert validator.validate()
+
+
+def test_proto_service_rpc_transitive_import_fail():
+ # main.proto imports bar.proto which imports baz.proto. Using baz.Foo as
+ # an RPC request/response type must fail because baz.proto is only
+ # transitively imported.
+ with tempfile.TemporaryDirectory() as tmpdir:
+ tmpdir = Path(tmpdir)
+ (tmpdir / "baz.proto").write_text(
+ """
+ syntax = "proto3";
+ package baz;
+
+ message Foo {}
+ """
+ )
+ (tmpdir / "bar.proto").write_text(
+ """
+ syntax = "proto3";
+ package bar;
+ import "baz.proto";
+
+ message Bar {}
+ """
+ )
+ main_proto = tmpdir / "main.proto"
+ main_proto.write_text(
+ """
+ syntax = "proto3";
+ package main;
+ import "bar.proto";
+
+ service FooService { rpc GetFoo(baz.Foo) returns (baz.Foo); }
+ """
+ )
+ with pytest.raises(FrontendError):
+ resolve_imports(main_proto, [tmpdir])
+
+
+def test_proto_same_type_and_package_names_fail():
+ # When a message name matches the package name, protoc rejects the relative
+ # qualified form `demo.demo` because the first `demo` resolves to the
+ # message (not the package) and that message has no nested `demo` type.
+ # The absolute form `.demo.demo` is valid.
+ source = """
+ syntax = "proto3";
+ package demo;
+
+ message demo {}
+
+ message Ref {
+ demo.demo d = 1;
+ }
+ """
+ with pytest.raises(FrontendError):
+ ProtoFrontend().parse(source)
diff --git a/compiler/fory_compiler/tests/test_proto_service.py b/compiler/fory_compiler/tests/test_proto_service.py
index 80297beffc..1862611629 100644
--- a/compiler/fory_compiler/tests/test_proto_service.py
+++ b/compiler/fory_compiler/tests/test_proto_service.py
@@ -17,10 +17,12 @@
"""Tests for Proto service parsing."""
+import pytest
+
+from fory_compiler.frontend.base import FrontendError
from fory_compiler.frontend.proto.lexer import Lexer
from fory_compiler.frontend.proto.parser import Parser
from fory_compiler.frontend.proto.translator import ProtoTranslator
-from fory_compiler.ir.validator import SchemaValidator
def parse_and_translate(source):
@@ -59,8 +61,10 @@ def test_service_parsing():
m1 = service.methods[0]
assert m1.name == "SayHello"
- assert m1.request_type.name == "Request"
- assert m1.response_type.name == "Response"
+ assert m1.request_type.name == "demo.Request"
+ assert m1.request_type.display_name == "Request"
+ assert m1.response_type.name == "demo.Response"
+ assert m1.response_type.display_name == "Response"
assert not m1.client_streaming
assert not m1.server_streaming
@@ -134,12 +138,8 @@ def test_service_unknown_request_type_fails_validation():
rpc SayHello (UnknownRequest) returns (Response);
}
"""
- schema = parse_and_translate(source)
- validator = SchemaValidator(schema)
- assert not validator.validate()
- assert any(
- "Unknown type 'UnknownRequest'" in err.message for err in validator.errors
- )
+ with pytest.raises(FrontendError, match="Unknown type 'UnknownRequest'"):
+ parse_and_translate(source)
def test_service_unknown_response_type_fails_validation():
@@ -153,7 +153,5 @@ def test_service_unknown_response_type_fails_validation():
rpc SayHello (Request) returns (UnknownReply);
}
"""
- schema = parse_and_translate(source)
- validator = SchemaValidator(schema)
- assert not validator.validate()
- assert any("Unknown type 'UnknownReply'" in err.message for err in validator.errors)
+ with pytest.raises(FrontendError, match="Unknown type 'UnknownReply'"):
+ parse_and_translate(source)