diff --git a/compiler/fory_compiler/cli.py b/compiler/fory_compiler/cli.py index 96325eaa8f..d236e4ae95 100644 --- a/compiler/fory_compiler/cli.py +++ b/compiler/fory_compiler/cli.py @@ -116,6 +116,9 @@ def resolve_imports( visited.add(file_path) + if file_path.suffix == ".proto": + return _resolve_proto_imports(file_path, import_paths, visited, cache) + # Parse the file schema = parse_idl_file(file_path) @@ -167,6 +170,77 @@ def resolve_imports( return merged_schema +def _resolve_proto_imports( + file_path: Path, + import_paths: Optional[List[Path]], + visited: Set[Path], + cache: Dict[Path, Schema], +) -> Schema: + """Proto-specific import resolution.""" + from fory_compiler.frontend.proto import ProtoFrontend + + frontend = ProtoFrontend() + source = file_path.read_text() + proto_schema = frontend.parse_ast(source, str(file_path)) + direct_import_proto_schemas = [] + imported_enums = [] + imported_messages = [] + imported_unions = [] + imported_services = [] + file_packages: Dict[str, Optional[str]] = { + str(file_path): proto_schema.package + } # file -> the package it belongs to. 
+ + for imp_path_str in proto_schema.imports: + import_path = resolve_import_path(imp_path_str, file_path, import_paths or []) + if import_path is None: + searched = [str(file_path.parent)] + if import_paths: + searched.extend(str(p) for p in import_paths) + raise ImportError( + f"Import not found: {imp_path_str}\n Searched in: {', '.join(searched)}" + ) + imp_source = import_path.read_text() + imp_proto_ast = frontend.parse_ast(imp_source, str(import_path)) + direct_import_proto_schemas.append(imp_proto_ast) + + # Recursively resolve the imported file + imported_full = resolve_imports( + import_path, import_paths, visited.copy(), cache + ) + imported_enums.extend(imported_full.enums) + imported_messages.extend(imported_full.messages) + imported_unions.extend(imported_full.unions) + imported_services.extend(imported_full.services) + + # Collect file->package mappings from the imported schema. + if imported_full.file_packages: + file_packages.update(imported_full.file_packages) + else: + file_packages[str(import_path)] = imported_full.package + + schema = frontend.parse_with_imports( + source, str(file_path), direct_import_proto_schemas + ) + + merged_schema = Schema( + package=schema.package, + package_alias=schema.package_alias, + imports=schema.imports, + enums=imported_enums + schema.enums, + messages=imported_messages + schema.messages, + unions=imported_unions + schema.unions, + services=imported_services + schema.services, + options=schema.options, + source_file=schema.source_file, + source_format=schema.source_format, + file_packages=file_packages, + ) + + cache[file_path] = copy.deepcopy(merged_schema) + return merged_schema + + def go_package_info(schema: Schema) -> Tuple[Optional[str], str]: """Return (import_path, package_name) for Go.""" go_package = schema.get_option("go_package") diff --git a/compiler/fory_compiler/frontend/proto/__init__.py b/compiler/fory_compiler/frontend/proto/__init__.py index 2d3a30e77f..f72da1bc16 100644 --- 
a/compiler/fory_compiler/frontend/proto/__init__.py +++ b/compiler/fory_compiler/frontend/proto/__init__.py @@ -18,8 +18,10 @@ """Proto frontend.""" import sys +from typing import List, Optional from fory_compiler.frontend.base import BaseFrontend, FrontendError +from fory_compiler.frontend.proto.ast import ProtoSchema from fory_compiler.frontend.proto.lexer import Lexer, LexerError from fory_compiler.frontend.proto.parser import Parser, ParseError from fory_compiler.frontend.proto.translator import ProtoTranslator @@ -32,15 +34,32 @@ class ProtoFrontend(BaseFrontend): extensions = [".proto"] def parse(self, source: str, filename: str = "") -> Schema: + return self.parse_with_imports(source, filename) + + def parse_ast(self, source: str, filename: str = "") -> ProtoSchema: + """Parse proto source into a proto AST without translating to Fory IR.""" try: lexer = Lexer(source, filename) tokens = lexer.tokenize() parser = Parser(tokens, filename) - proto_schema = parser.parse() + return parser.parse() except (LexerError, ParseError) as exc: raise FrontendError(exc.message, filename, exc.line, exc.column) from exc - translator = ProtoTranslator(proto_schema) + def parse_with_imports( + self, + source: str, + filename: str = "", + direct_import_proto_schemas: Optional[List[ProtoSchema]] = None, + ) -> Schema: + """Parse proto source and translate to Fory IR. + + `direct_import_proto_schemas` supplies the proto ASTs of **directly** + imported files so the translator can resolve cross-file type references + and enforce import-visibility rules. 
+ """ + proto_schema = self.parse_ast(source, filename) + translator = ProtoTranslator(proto_schema, direct_import_proto_schemas) schema = translator.translate() for warning in translator.warnings: diff --git a/compiler/fory_compiler/frontend/proto/translator.py b/compiler/fory_compiler/frontend/proto/translator.py index 5554bb670b..24b5a37203 100644 --- a/compiler/fory_compiler/frontend/proto/translator.py +++ b/compiler/fory_compiler/frontend/proto/translator.py @@ -49,8 +49,24 @@ from fory_compiler.ir.types import PrimitiveKind +class TranslationError(Exception): + """Raised when a type reference cannot be resolved during proto translation.""" + + def __init__(self, message: str, line: int = 0, column: int = 0) -> None: + super().__init__(message) + self.line = line + self.column = column + + class ProtoTranslator: - """Translate Proto AST to Fory IR.""" + """Translate Proto AST to Fory IR. + + Accepts an optional list of *direct* import proto schemas so that type + references are resolved to fully-qualified names and import-visibility is enforced during + translation. + Types from transitively-imported files (not in `direct_import_proto_schemas`) + are absent from the symbol table and will cause a resolution error, matching protoc semantics. + """ TYPE_MAPPING: Dict[str, PrimitiveKind] = { "bool": PrimitiveKind.BOOL, @@ -86,9 +102,141 @@ class ProtoTranslator: "tagged_uint64": PrimitiveKind.TAGGED_UINT64, } - def __init__(self, proto_schema: ProtoSchema): + def __init__( + self, + proto_schema: ProtoSchema, + direct_import_proto_schemas: Optional[List[ProtoSchema]] = None, + ): self.proto_schema = proto_schema + self.direct_import_proto_schemas: List[ProtoSchema] = ( + direct_import_proto_schemas or [] + ) self.warnings: List[str] = [] + # symbol table: fully-qualified name -> (source_file, package). + # Only own file and directly-imported files are included in the symbol table while transitively-imported are excluded. 
+ self._symbol_table: Dict[str, Tuple[str, Optional[str]]] = ( + self._build_symbol_table() + ) + + def _build_symbol_table(self) -> Dict[str, Tuple[str, Optional[str]]]: + table: Dict[str, Tuple[str, Optional[str]]] = {} + own_file = self.proto_schema.source_file or "" + own_pkg = self.proto_schema.package + self._collect_proto_message_qualified_names( + self.proto_schema.messages, own_pkg, "", own_file, table + ) + self._collect_proto_enum_qualified_names( + self.proto_schema.enums, own_pkg, "", own_file, table + ) + for imp_ps in self.direct_import_proto_schemas: + imp_file = imp_ps.source_file or "" + imp_pkg = imp_ps.package + self._collect_proto_message_qualified_names( + imp_ps.messages, imp_pkg, "", imp_file, table + ) + self._collect_proto_enum_qualified_names( + imp_ps.enums, imp_pkg, "", imp_file, table + ) + return table + + def _collect_proto_message_qualified_names( + self, + messages: List[ProtoMessage], + package: Optional[str], + parent_path: str, + source_file: str, + table: Dict[str, Tuple[str, Optional[str]]], + ) -> None: + for msg in messages: + path = f"{parent_path}.{msg.name}" if parent_path else msg.name + qualified_name = f"{package}.{path}" if package else path + table[qualified_name] = (source_file, package) + # Handle nested messages. + self._collect_proto_message_qualified_names( + msg.nested_messages, package, path, source_file, table + ) + # Handle nested enums. 
+ self._collect_proto_enum_qualified_names( + msg.nested_enums, package, path, source_file, table + ) + + def _collect_proto_enum_qualified_names( + self, + enums: List[ProtoEnum], + package: Optional[str], + parent_path: str, + source_file: str, + table: Dict[str, Tuple[str, Optional[str]]], + ) -> None: + for enum in enums: + path = f"{parent_path}.{enum.name}" if parent_path else enum.name + qualified_name = f"{package}.{path}" if package else path + table[qualified_name] = (source_file, package) + + def _resolve_ref( + self, + raw_name: str, + enclosing_path: List[str], + line: int, + column: int, + ) -> str: + """Resolve a proto type-reference string to its fully-qualified name. + + Raise `TranslationError` if the name cannot be resolved or is not + visible (i.e. not in own file or a directly-imported file). + """ + cleaned = raw_name.lstrip(".") + is_absolute = raw_name.startswith(".") + + if is_absolute: + # Absolute names (with leading dot) are looked up directly, + # e.g.: ".com.example.Foo" -> "com.example.Foo". + if cleaned in self._symbol_table: + return cleaned + raise TranslationError(f"Unknown type '{raw_name}'", line, column) + + parts = cleaned.split(".") + own_pkg = self.proto_schema.package + + # Build scope-prefix list from innermost to outermost scope. + # Ref: https://protobuf.dev/programming-guides/proto3/#name-resolution + # e.g., for a reference inside package "com.example", message "Outer", nested + # message "Inner" (enclosing_path = ["Outer", "Inner"]) the prefixes are: + # ["com.example.Outer.Inner", "com.example.Outer", "com.example"]. 
+ scope_prefixes: List[Optional[str]] = [] + for depth in range(len(enclosing_path), -1, -1): + scope_parts = enclosing_path[:depth] + if scope_parts: + inner = ".".join(scope_parts) + scope_prefixes.append(f"{own_pkg}.{inner}" if own_pkg else inner) + else: + scope_prefixes.append(own_pkg) + + for prefix in scope_prefixes: + first_qualified_name = f"{prefix}.{parts[0]}" if prefix else parts[0] + if first_qualified_name not in self._symbol_table: + continue + + if len(parts) == 1: + return first_qualified_name + + current = first_qualified_name + for part in parts[1:]: + nxt = f"{current}.{part}" + if nxt not in self._symbol_table: + raise TranslationError( + f"Nested type '{part}' not found in '{current}'; " + f"cannot resolve '{raw_name}'", + line, + column, + ) + current = nxt + return current + + if cleaned in self._symbol_table: + return cleaned + + raise TranslationError(f"Unknown type '{raw_name}'", line, column) def _location(self, line: int, column: int) -> SourceLocation: return SourceLocation( @@ -99,16 +247,27 @@ def _location(self, line: int, column: int) -> SourceLocation: ) def translate(self) -> Schema: + # Collect the file_packages mapping so the merged schema can do fully-qualified name lookup later + file_packages: Dict[str, Optional[str]] = { + self.proto_schema.source_file or "": self.proto_schema.package + } + for imp_ps in self.direct_import_proto_schemas: + imp_file = imp_ps.source_file or "" + file_packages[imp_file] = imp_ps.package + return Schema( package=self.proto_schema.package, package_alias=None, imports=self._translate_imports(), enums=[self._translate_enum(e) for e in self.proto_schema.enums], - messages=[self._translate_message(m) for m in self.proto_schema.messages], + messages=[ + self._translate_message(m, []) for m in self.proto_schema.messages + ], services=[self._translate_service(s) for s in self.proto_schema.services], options=self._translate_file_options(self.proto_schema.options), 
source_file=self.proto_schema.source_file, source_format="proto", + file_packages=file_packages, ) def _translate_imports(self) -> List[Import]: @@ -145,21 +304,24 @@ def _translate_enum(self, proto_enum: ProtoEnum) -> Enum: location=self._location(proto_enum.line, proto_enum.column), ) - def _translate_message(self, proto_msg: ProtoMessage) -> Message: + def _translate_message( + self, proto_msg: ProtoMessage, enclosing_path: List[str] + ) -> Message: type_id, options = self._translate_type_options(proto_msg.options) - fields = [self._translate_field(f) for f in proto_msg.fields] + msg_path = enclosing_path + [proto_msg.name] + fields = [self._translate_field(f, msg_path) for f in proto_msg.fields] nested_unions = [] for oneof in proto_msg.oneofs: oneof_type_name = self._oneof_type_name(oneof.name) nested_unions.append( - self._translate_oneof(oneof, oneof_type_name, proto_msg) + self._translate_oneof(oneof, oneof_type_name, proto_msg, msg_path) ) if not oneof.fields: continue union_field = self._translate_oneof_field_reference(oneof, oneof_type_name) fields.append(union_field) nested_messages = [ - self._translate_message(m) for m in proto_msg.nested_messages + self._translate_message(m, msg_path) for m in proto_msg.nested_messages ] nested_enums = [self._translate_enum(e) for e in proto_msg.nested_enums] return Message( @@ -175,8 +337,10 @@ def _translate_message(self, proto_msg: ProtoMessage) -> Message: location=self._location(proto_msg.line, proto_msg.column), ) - def _translate_field(self, proto_field: ProtoField) -> Field: - field_type = self._translate_field_type(proto_field.field_type) + def _translate_field( + self, proto_field: ProtoField, enclosing_path: List[str] + ) -> Field: + field_type = self._translate_field_type(proto_field.field_type, enclosing_path) ref, nullable, options, type_override = self._translate_field_options( proto_field.options ) @@ -251,8 +415,9 @@ def _translate_oneof( oneof: ProtoOneof, oneof_type_name: str, _parent: 
ProtoMessage, + enclosing_path: List[str], ) -> Union: - fields = [self._translate_oneof_case(f) for f in oneof.fields] + fields = [self._translate_oneof_case(f, enclosing_path) for f in oneof.fields] return Union( name=oneof_type_name, type_id=None, @@ -263,8 +428,10 @@ def _translate_oneof( location=self._location(oneof.line, oneof.column), ) - def _translate_oneof_case(self, proto_field: ProtoField) -> Field: - field_type = self._translate_field_type(proto_field.field_type) + def _translate_oneof_case( + self, proto_field: ProtoField, enclosing_path: List[str] + ) -> Field: + field_type = self._translate_field_type(proto_field.field_type, enclosing_path) ref, _nullable, options, type_override = self._translate_field_options( proto_field.options ) @@ -303,20 +470,35 @@ def _translate_oneof_field_reference( location=self._location(oneof.line, oneof.column), ) - def _translate_field_type(self, proto_type: ProtoType): + def _translate_field_type( + self, proto_type: ProtoType, enclosing_path: List[str] + ) -> FieldType: if proto_type.is_map: - key_type = self._translate_type_name(proto_type.map_key_type or "") - value_type = self._translate_type_name(proto_type.map_value_type or "") + key_type = self._translate_type_name( + proto_type.map_key_type or "", [], proto_type.line, proto_type.column + ) + value_type = self._translate_type_name( + proto_type.map_value_type or "", + enclosing_path, + proto_type.line, + proto_type.column, + ) return MapType( key_type, value_type, location=self._location(proto_type.line, proto_type.column), ) return self._translate_type_name( - proto_type.name, proto_type.line, proto_type.column + proto_type.name, enclosing_path, proto_type.line, proto_type.column ) - def _translate_type_name(self, type_name: str, line: int = 0, column: int = 0): + def _translate_type_name( + self, + type_name: str, + enclosing_path: List[str], + line: int = 0, + column: int = 0, + ) -> FieldType: cleaned = type_name.lstrip(".") if cleaned in 
self.WELL_KNOWN_TYPES: return PrimitiveType( @@ -328,7 +510,39 @@ def _translate_type_name(self, type_name: str, line: int = 0, column: int = 0): self.TYPE_MAPPING[cleaned], location=self._location(line, column), ) - return NamedType(cleaned, location=self._location(line, column)) + # Resolve user-defined type reference to its fully-qualified name. + try: + qualified_name = self._resolve_ref(type_name, enclosing_path, line, column) + except TranslationError as exc: + from fory_compiler.frontend.base import FrontendError + + raise FrontendError( + str(exc), + self.proto_schema.source_file or "", + exc.line, + exc.column, + ) from exc + # Compute display_name: the name that code generators should use as output type string. + cleaned = type_name.lstrip(".") + _, type_pkg = self._symbol_table.get(qualified_name, (None, None)) + own_pkg = self.proto_schema.package + if type_pkg == own_pkg: + # Same package: use the written reference, minus any redundant package prefix. + if own_pkg and cleaned.startswith(own_pkg + "."): + display_name: Optional[str] = cleaned[len(own_pkg) + 1 :] + else: + display_name = cleaned + else: + # Cross-package: strip the type's package prefix so generators get the type-local path. 
+ if type_pkg and qualified_name.startswith(type_pkg + "."): + display_name = qualified_name[len(type_pkg) + 1 :] + else: + display_name = qualified_name + return NamedType( + qualified_name, + location=self._location(line, column), + display_name=display_name, + ) def _translate_type_options( self, options: Dict[str, object] @@ -403,16 +617,16 @@ def _translate_service(self, proto_service: ProtoService) -> Service: def _translate_rpc_method(self, proto_method: ProtoRpcMethod) -> RpcMethod: # Translate ProtoRpcMethod to RpcMethod _, options = self._translate_type_options(proto_method.options) + req_type = self._translate_type_name( + proto_method.request_type, [], proto_method.line, proto_method.column + ) + resp_type = self._translate_type_name( + proto_method.response_type, [], proto_method.line, proto_method.column + ) return RpcMethod( name=proto_method.name, - request_type=NamedType( - name=proto_method.request_type, - location=self._location(proto_method.line, proto_method.column), - ), - response_type=NamedType( - name=proto_method.response_type, - location=self._location(proto_method.line, proto_method.column), - ), + request_type=req_type, + response_type=resp_type, client_streaming=proto_method.client_streaming, server_streaming=proto_method.server_streaming, options=options, diff --git a/compiler/fory_compiler/generators/base.py b/compiler/fory_compiler/generators/base.py index 379f9c1b54..ca9cadea45 100644 --- a/compiler/fory_compiler/generators/base.py +++ b/compiler/fory_compiler/generators/base.py @@ -229,7 +229,7 @@ def format_idl_type(self, field_type: FieldType) -> str: if isinstance(field_type, PrimitiveType): return field_type.kind.value if isinstance(field_type, NamedType): - return field_type.name + return field_type.display_name or field_type.name if isinstance(field_type, ListType): element = self.format_idl_type(field_type.element_type) return f"list<{element}>" diff --git a/compiler/fory_compiler/generators/cpp.py 
b/compiler/fory_compiler/generators/cpp.py index a5bc8b418c..e842fed24d 100644 --- a/compiler/fory_compiler/generators/cpp.py +++ b/compiler/fory_compiler/generators/cpp.py @@ -503,7 +503,7 @@ def collect_type_dependencies( if isinstance(field_type, PrimitiveType): return if isinstance(field_type, NamedType): - type_name = field_type.name + type_name = field_type.display_name or field_type.name if self.is_nested_type_reference(type_name, parent_stack): return top_level = type_name.split(".")[0] @@ -598,7 +598,9 @@ def is_message_type( ) -> bool: if not isinstance(field_type, NamedType): return False - resolved = self.resolve_named_type(field_type.name, parent_stack) + resolved = self.resolve_named_type( + field_type.display_name or field_type.name, parent_stack + ) return isinstance(resolved, Message) def is_weak_ref(self, options: dict) -> bool: @@ -620,7 +622,9 @@ def is_union_type( ) -> bool: if not isinstance(field_type, NamedType): return False - resolved = self.resolve_named_type(field_type.name, parent_stack) + resolved = self.resolve_named_type( + field_type.display_name or field_type.name, parent_stack + ) return isinstance(resolved, Union) def is_enum_type( @@ -628,7 +632,9 @@ def is_enum_type( ) -> bool: if not isinstance(field_type, NamedType): return False - resolved = self.resolve_named_type(field_type.name, parent_stack) + resolved = self.resolve_named_type( + field_type.display_name or field_type.name, parent_stack + ) return isinstance(resolved, Enum) def get_field_member_name(self, field: Field) -> str: @@ -1474,7 +1480,8 @@ def generate_namespaced_type( return base_type if isinstance(field_type, NamedType): - type_name = self.resolve_nested_type_name(field_type.name, parent_stack) + local_name = field_type.display_name or field_type.name + type_name = self.resolve_nested_type_name(local_name, parent_stack) named_type = self.schema.get_type(field_type.name) if named_type is not None and self.is_imported_type(named_type): namespace = 
self._namespace_for_type(named_type) @@ -1652,7 +1659,8 @@ def generate_type( return base_type elif isinstance(field_type, NamedType): - type_name = self.resolve_nested_type_name(field_type.name, parent_stack) + local_name = field_type.display_name or field_type.name + type_name = self.resolve_nested_type_name(local_name, parent_stack) named_type = self.schema.get_type(field_type.name) if named_type is not None and self.is_imported_type(named_type): ns = self._namespace_for_type(named_type) diff --git a/compiler/fory_compiler/generators/go.py b/compiler/fory_compiler/generators/go.py index 0fd9d4b0c1..6b7d1be004 100644 --- a/compiler/fory_compiler/generators/go.py +++ b/compiler/fory_compiler/generators/go.py @@ -674,7 +674,9 @@ def get_union_case_type_id_expr( if isinstance(field.field_type, MapType): return "fory.MAP" if isinstance(field.field_type, NamedType): - type_def = self.resolve_named_type(field.field_type.name, parent_stack) + type_def = self.resolve_named_type( + field.field_type.display_name or field.field_type.name, parent_stack + ) if isinstance(type_def, Enum): if type_def.type_id is None: return "fory.NAMED_ENUM" @@ -970,7 +972,8 @@ def generate_type( return base_type elif isinstance(field_type, NamedType): - type_name = self.resolve_nested_type_name(field_type.name, parent_stack) + local_name = field_type.display_name or field_type.name + type_name = self.resolve_nested_type_name(local_name, parent_stack) named_type = self.schema.get_type(field_type.name) if named_type is not None and self.is_imported_type(named_type): info = self._import_info_for_type(named_type) @@ -978,7 +981,7 @@ def generate_type( getattr(getattr(named_type, "location", None), "file", None) ) if schema is not None: - type_name = self._format_imported_type_name(field_type.name, schema) + type_name = self._format_imported_type_name(local_name, schema) if info is not None: alias, _, _ = info type_name = f"{alias}.{type_name}" @@ -1032,9 +1035,8 @@ def get_union_case_type( ) -> 
str: """Return the Go type for a union case.""" if isinstance(field.field_type, NamedType): - type_name = self.resolve_nested_type_name( - field.field_type.name, parent_stack - ) + local_name = field.field_type.display_name or field.field_type.name + type_name = self.resolve_nested_type_name(local_name, parent_stack) named_type = self.schema.get_type(field.field_type.name) if named_type is not None and self.is_imported_type(named_type): info = self._import_info_for_type(named_type) @@ -1042,9 +1044,7 @@ def get_union_case_type( getattr(getattr(named_type, "location", None), "file", None) ) if schema is not None: - type_name = self._format_imported_type_name( - field.field_type.name, schema - ) + type_name = self._format_imported_type_name(local_name, schema) if info is not None: alias, _, _ = info type_name = f"{alias}.{type_name}" diff --git a/compiler/fory_compiler/generators/java.py b/compiler/fory_compiler/generators/java.py index e30f4f26d3..854a3cf29e 100644 --- a/compiler/fory_compiler/generators/java.py +++ b/compiler/fory_compiler/generators/java.py @@ -896,7 +896,9 @@ def get_union_case_type_id_expr( if isinstance(field.field_type, MapType): return "Types.MAP" if isinstance(field.field_type, NamedType): - type_def = self.resolve_named_type(field.field_type.name, parent_stack) + type_def = self.resolve_named_type( + field.field_type.display_name or field.field_type.name, parent_stack + ) if isinstance(type_def, Enum): if type_def.type_id is None: return "Types.NAMED_ENUM" @@ -1135,11 +1137,12 @@ def generate_type( elif isinstance(field_type, NamedType): named_type = self.schema.get_type(field_type.name) + local_name = field_type.display_name or field_type.name if named_type is not None and self.is_imported_type(named_type): java_package = self._java_package_for_type(named_type) if java_package: - return f"{java_package}.{field_type.name}" - return field_type.name + return f"{java_package}.{local_name}" + return local_name elif isinstance(field_type, 
ListType): # Use specialized primitive lists when available, otherwise primitive arrays. diff --git a/compiler/fory_compiler/generators/python.py b/compiler/fory_compiler/generators/python.py index b4c049ff82..5fbe82c540 100644 --- a/compiler/fory_compiler/generators/python.py +++ b/compiler/fory_compiler/generators/python.py @@ -727,7 +727,8 @@ def generate_type( return base_type elif isinstance(field_type, NamedType): - type_name = self.resolve_nested_type_name(field_type.name, parent_stack) + local_name = field_type.display_name or field_type.name + type_name = self.resolve_nested_type_name(local_name, parent_stack) named_type = self.schema.get_type(field_type.name) if named_type is not None and self.is_imported_type(named_type): module = self._module_name_for_type(named_type) @@ -906,7 +907,7 @@ def get_union_case_runtime_check( return f"isinstance({value_expr}, datetime.datetime)" if isinstance(field.field_type, NamedType): type_name = self.resolve_nested_type_name( - field.field_type.name, parent_stack + field.field_type.display_name or field.field_type.name, parent_stack ) return f"isinstance({value_expr}, {type_name})" return None diff --git a/compiler/fory_compiler/generators/rust.py b/compiler/fory_compiler/generators/rust.py index 48fad1f9b9..069f3e128f 100644 --- a/compiler/fory_compiler/generators/rust.py +++ b/compiler/fory_compiler/generators/rust.py @@ -703,12 +703,13 @@ def generate_type( return base_type elif isinstance(field_type, NamedType): - type_name = self.resolve_nested_type_name(field_type.name, parent_stack) + local_name = field_type.display_name or field_type.name + type_name = self.resolve_nested_type_name(local_name, parent_stack) named_type = self.schema.get_type(field_type.name) if named_type is not None and self.is_imported_type(named_type): module = self._module_name_for_type(named_type) if module: - type_name = self._format_imported_type_name(field_type.name, module) + type_name = self._format_imported_type_name(local_name, 
module) if ref: type_name = f"{pointer_type}<{type_name}>" if nullable: diff --git a/compiler/fory_compiler/ir/ast.py b/compiler/fory_compiler/ir/ast.py index 0c1d750099..62e6bcc96a 100644 --- a/compiler/fory_compiler/ir/ast.py +++ b/compiler/fory_compiler/ir/ast.py @@ -18,7 +18,7 @@ """AST node definitions for FDL.""" from dataclasses import dataclass, field -from typing import List, Optional, Union as TypingUnion +from typing import Dict, List, Optional, Union as TypingUnion from fory_compiler.ir.types import PrimitiveKind @@ -46,10 +46,16 @@ def __repr__(self) -> str: @dataclass class NamedType: - """A reference to a user-defined type (message or enum).""" + """A reference to a user-defined type (message or enum). + + `name` is always the lookup key (fully-qualified name for proto schemas). + `display_name`, when set, is the package-relative name that code generators should use as the output type string. + Generators fall back to `name` when `display_name` is None (FDL / FBS frontends). + """ name: str location: Optional[SourceLocation] = None + display_name: Optional[str] = None def __repr__(self) -> str: return f"NamedType({self.name})" @@ -302,6 +308,8 @@ class Schema: ) # File-level options (java_package, go_package, etc.) source_file: Optional[str] = None source_format: Optional[str] = None + # Maps absolute file path -> package name. + file_packages: Optional[Dict[str, Optional[str]]] = None def __repr__(self) -> str: opts = f", options={len(self.options)}" if self.options else "" @@ -317,26 +325,85 @@ def get_option(self, name: str, default: Optional[str] = None) -> Optional[str]: return self.options.get(name, default) def get_type(self, name: str) -> Optional[TypingUnion[Message, Enum, "Union"]]: - """Look up a type by name, supporting qualified names like Parent.Child.""" - # Handle qualified names (e.g., SearchResponse.Result) - if "." 
in name: - parts = name.split(".") - # Find the top-level type - current = self._get_top_level_type(parts[0]) - if current is None: - return None - # Navigate through nested types + """Look up a type by name, supporting qualified names like Parent.Child. + + For proto schemas the translator emits fully-qualified names (e.g. `pkg.Outer.Inner`). + When `file_packages` is set we build a fully-qualified name index on first access so + those names resolve correctly across a merged multi-package schema. + """ + if "." not in name: + return self._get_top_level_type(name) + + # Try stripping the own package prefix first. + if self.package and name.startswith(self.package + "."): + rest = name[len(self.package) + 1 :] + result = self.get_type(rest) + if result is not None: + return result + + # Try dot-separated nested navigation (handles Outer.Inner, A.B.C). + parts = name.split(".") + current = self._get_top_level_type(parts[0]) + if current is not None: for part in parts[1:]: if isinstance(current, Message): current = current.get_nested_type(part) if current is None: - return None + break else: - # Enums don't have nested types - return None - return current - else: - return self._get_top_level_type(name) + current = None + break + if current is not None: + return current + + # For merged proto schemas: look up by full qualified name across all packages. 
+ if self.file_packages is not None: + return self._get_qualified_name_index().get(name) + + return None + + def _get_qualified_name_index( + self, + ) -> Dict[str, TypingUnion[Message, Enum, "Union"]]: + """Build (and cache) a fully-qualified name -> type index for merged proto schemas.""" + cached = getattr(self, "_qualified_name_index_cache", None) + if cached is not None: + return cached + + index: Dict[str, TypingUnion[Message, Enum, "Union"]] = {} + + def pkg_for(location: Optional[SourceLocation]) -> Optional[str]: + if not location or not self.file_packages: + return self.package + return self.file_packages.get(location.file, self.package) + + def add( + type_def: TypingUnion[Message, Enum, "Union"], + path: str, + package: Optional[str], + ) -> None: + qualified_name = f"{package}.{path}" if package else path + index[qualified_name] = type_def + + def walk(msg: Message, package: Optional[str], parent: str) -> None: + path = f"{parent}.{msg.name}" if parent else msg.name + add(msg, path, package) + for e in msg.nested_enums: + add(e, f"{path}.{e.name}", package) + for u in msg.nested_unions: + add(u, f"{path}.{u.name}", package) + for m in msg.nested_messages: + walk(m, package, path) + + for e in self.enums: + add(e, e.name, pkg_for(e.location)) + for u in self.unions: + add(u, u.name, pkg_for(u.location)) + for m in self.messages: + walk(m, pkg_for(m.location), "") + + self._qualified_name_index_cache = index + return index def _get_top_level_type( self, name: str diff --git a/compiler/fory_compiler/ir/validator.py b/compiler/fory_compiler/ir/validator.py index 294657b2f9..3bad7d0cc6 100644 --- a/compiler/fory_compiler/ir/validator.py +++ b/compiler/fory_compiler/ir/validator.py @@ -102,9 +102,18 @@ def resolve_hash_source(full_name: str, alias: Optional[str]) -> str: return f"{package}.{alias}" return alias + own_file = self.schema.source_file or "" + def assign_id(type_def, full_name: str) -> None: if type_def.type_id is not None: return + # For 
multi-file proto schemas, only assign auto-IDs for the types + # defined in the current file; imported types will get their own + # IDs when their source file is compiled. + if self.schema.file_packages is not None: + type_file = type_def.location.file if type_def.location else None + if type_file != own_file: + return alias = type_def.options.get("alias") source_name = resolve_hash_source(full_name, alias) generated_id = compute_registered_type_id(source_name) @@ -143,29 +152,32 @@ def walk_message(message: Message, parent_path: str = "") -> None: for message in self.schema.messages: walk_message(message) + def _type_qualified_name_key(self, type_def) -> str: + """Return a deduplication key for a type definition. + + For proto merged schemas each type is identified by its fully-qualified name + so that types with the same simple name from different packages are not treated + as duplicates. + """ + if self.schema.file_packages is not None and type_def.location: + pkg = self.schema.file_packages.get(type_def.location.file) + return f"{pkg}.{type_def.name}" if pkg else type_def.name + return type_def.name + def _check_duplicate_type_names(self) -> None: names = {} - for enum in self.schema.enums: - if enum.name in names: - self._error( - f"Duplicate type name: {enum.name}", - enum.location or names[enum.name], - ) - names.setdefault(enum.name, enum.location) - for union in self.schema.unions: - if union.name in names: - self._error( - f"Duplicate type name: {union.name}", - union.location or names[union.name], - ) - names.setdefault(union.name, union.location) - for message in self.schema.messages: - if message.name in names: + for type_def in ( + list(self.schema.enums) + + list(self.schema.unions) + + list(self.schema.messages) + ): + key = self._type_qualified_name_key(type_def) + if key in names: self._error( - f"Duplicate type name: {message.name}", - message.location or names[message.name], + f"Duplicate type name: {type_def.name}", + type_def.location or 
names[key], ) - names.setdefault(message.name, message.location) + names.setdefault(key, type_def.location) def _check_duplicate_type_ids(self) -> None: type_ids = {} @@ -312,6 +324,10 @@ def _resolve_named_type( parts = name.split(".") if len(parts) > 1: current = self._find_top_level_type(parts[0]) + if current is None: + # Might be a fully-qualified name emitted by the proto translator; delegate to + # Schema.get_type() which handles package-prefix resolution. + return self.schema.get_type(name) for part in parts[1:]: if isinstance(current, Message): current = current.get_nested_type(part) diff --git a/compiler/fory_compiler/tests/test_proto_frontend.py b/compiler/fory_compiler/tests/test_proto_frontend.py index 8d87107a04..a632b77fb8 100644 --- a/compiler/fory_compiler/tests/test_proto_frontend.py +++ b/compiler/fory_compiler/tests/test_proto_frontend.py @@ -17,9 +17,16 @@ """Tests for the proto frontend translation.""" +import pytest +import tempfile +from pathlib import Path + +from fory_compiler.frontend.base import FrontendError from fory_compiler.frontend.proto import ProtoFrontend from fory_compiler.ir.ast import PrimitiveType from fory_compiler.ir.types import PrimitiveKind +from fory_compiler.cli import resolve_imports +from fory_compiler.ir.validator import SchemaValidator def test_proto_type_mapping(): @@ -106,3 +113,292 @@ def test_proto_file_option_enable_auto_type_id(): """ schema = ProtoFrontend().parse(source) assert schema.get_option("enable_auto_type_id") is False + + +def test_proto_nested_qualified_types_pass(): + source = """ + syntax = "proto3"; + package com.example; + + message A { + message B { + message C {} + } + } + message Outer { + A.B.C c1 = 1; + com.example.A.B.C c2 = 2; + .com.example.A.B.C c3 = 3; + } + """ + schema = ProtoFrontend().parse(source) + validator = SchemaValidator(schema) + assert validator.validate() + + +def test_proto_nested_qualified_types_fail(): + # X is only accessible as A.X; pure X is not in scope at B's 
level. + source = """ + syntax = "proto3"; + package demo; + + message A { + message X{} + } + message B { + X x1 = 1; + X x2 = 2; + } + """ + with pytest.raises(FrontendError): + ProtoFrontend().parse(source) + + +def test_proto_same_package_qualified_types_pass(): + source = """ + syntax = "proto3"; + package com.example; + + message Foo {} + + message Bar { + Foo f1 = 1; + com.example.Foo f2 = 2; + .com.example.Foo f3 = 3; + } + """ + schema = ProtoFrontend().parse(source) + validator = SchemaValidator(schema) + assert validator.validate() + + +def test_proto_imported_package_qualified_types_fail(): + # Pure 'Address' is not visible from package 'main'; must use 'common.Address'. + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + common_proto = tmpdir / "common.proto" + common_proto.write_text( + """ + syntax = "proto3"; + package common; + + message Address {} + """ + ) + main_proto = tmpdir / "main.proto" + main_proto.write_text( + """ + syntax = "proto3"; + package main; + import "common.proto"; + + message User { + Address addr1 = 1; + Address addr2 = 2; + } + """ + ) + with pytest.raises(FrontendError): + resolve_imports(main_proto, [tmpdir]) + + +def test_proto_imported_package_qualified_types_pass(): + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + common1_proto = tmpdir / "common1.proto" + common1_proto.write_text( + """ + syntax = "proto3"; + package com.lib1; + + message Address { + message Country {} + } + """ + ) + common2_proto = tmpdir / "common2.proto" + common2_proto.write_text( + """ + syntax = "proto3"; + package com.lib2; + + message Address {} + """ + ) + main_proto = tmpdir / "main.proto" + main_proto.write_text( + """ + syntax = "proto3"; + package main; + import "common1.proto"; + import "common2.proto"; + + message User { + com.lib1.Address a1 = 1; + .com.lib2.Address a2 = 2; + com.lib1.Address.Country a3 = 3; + } + """ + ) + schema = resolve_imports(main_proto, [tmpdir]) + validator = 
SchemaValidator(schema) + assert validator.validate() + + +def test_proto_transitive_imports_fail(): + # baz.proto is only transitively imported via bar.proto; main.proto must not reference baz.Foo directly. + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + (tmpdir / "baz.proto").write_text( + """ + syntax = "proto3"; + package baz; + + message Foo {} + """ + ) + (tmpdir / "bar.proto").write_text( + """ + syntax = "proto3"; + package bar; + import "baz.proto"; + + message Bar { baz.Foo foo = 1; } + """ + ) + main_proto = tmpdir / "main.proto" + main_proto.write_text( + """ + syntax = "proto3"; + package main; + import "bar.proto"; + + message User { baz.Foo foo = 1; } + """ + ) + with pytest.raises(FrontendError): + resolve_imports(main_proto, [tmpdir]) + + +def test_proto_same_package_transitive_import_fail(): + # c1.proto and c2.proto share package common; main.proto only imports + # c2.proto. Referencing common.Foo (defined in c1.proto) must fail because + # c1.proto is not directly imported, even though its package name matches + # the directly-imported c2.proto. + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + (tmpdir / "c1.proto").write_text( + """ + syntax = "proto3"; + package common; + + message Foo {} + """ + ) + (tmpdir / "c2.proto").write_text( + """ + syntax = "proto3"; + package common; + import "c1.proto"; + + message Bar {} + """ + ) + main_proto = tmpdir / "main.proto" + main_proto.write_text( + """ + syntax = "proto3"; + package main; + import "c2.proto"; + + message User { common.Foo foo = 1; } + """ + ) + with pytest.raises(FrontendError): + resolve_imports(main_proto, [tmpdir]) + + +def test_proto_local_type_shadows_import_pass(): + # main.proto defines its own Address and also imports common.proto which + # defines another Address. An unqualified field reference Address should + # resolve to the local definition and pass validation. 
+ with tempfile.TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + (tmpdir / "common.proto").write_text( + """ + syntax = "proto3"; + package common; + + message Address {} + """ + ) + main_proto = tmpdir / "main.proto" + main_proto.write_text( + """ + syntax = "proto3"; + package main; + import "common.proto"; + + message Address {} + message User { Address addr = 1; } + """ + ) + schema = resolve_imports(main_proto, [tmpdir]) + validator = SchemaValidator(schema) + assert validator.validate() + + +def test_proto_service_rpc_transitive_import_fail(): + # main.proto imports bar.proto which imports baz.proto. Using baz.Foo as + # an RPC request/response type must fail because baz.proto is only + # transitively imported. + with tempfile.TemporaryDirectory() as tmpdir: + tmpdir = Path(tmpdir) + (tmpdir / "baz.proto").write_text( + """ + syntax = "proto3"; + package baz; + + message Foo {} + """ + ) + (tmpdir / "bar.proto").write_text( + """ + syntax = "proto3"; + package bar; + import "baz.proto"; + + message Bar {} + """ + ) + main_proto = tmpdir / "main.proto" + main_proto.write_text( + """ + syntax = "proto3"; + package main; + import "bar.proto"; + + service FooService { rpc GetFoo(baz.Foo) returns (baz.Foo); } + """ + ) + with pytest.raises(FrontendError): + resolve_imports(main_proto, [tmpdir]) + + +def test_proto_same_type_and_package_names_fail(): + # When a message name matches the package name, protoc rejects the relative + # qualified form `demo.demo` because the first `demo` resolves to the + # message (not the package) and that message has no nested `demo` type. + # The absolute form `.demo.demo` is valid. 
+ source = """ + syntax = "proto3"; + package demo; + + message demo {} + + message Ref { + demo.demo d = 1; + } + """ + with pytest.raises(FrontendError): + ProtoFrontend().parse(source) diff --git a/compiler/fory_compiler/tests/test_proto_service.py b/compiler/fory_compiler/tests/test_proto_service.py index 80297beffc..1862611629 100644 --- a/compiler/fory_compiler/tests/test_proto_service.py +++ b/compiler/fory_compiler/tests/test_proto_service.py @@ -17,10 +17,12 @@ """Tests for Proto service parsing.""" +import pytest + +from fory_compiler.frontend.base import FrontendError from fory_compiler.frontend.proto.lexer import Lexer from fory_compiler.frontend.proto.parser import Parser from fory_compiler.frontend.proto.translator import ProtoTranslator -from fory_compiler.ir.validator import SchemaValidator def parse_and_translate(source): @@ -59,8 +61,10 @@ def test_service_parsing(): m1 = service.methods[0] assert m1.name == "SayHello" - assert m1.request_type.name == "Request" - assert m1.response_type.name == "Response" + assert m1.request_type.name == "demo.Request" + assert m1.request_type.display_name == "Request" + assert m1.response_type.name == "demo.Response" + assert m1.response_type.display_name == "Response" assert not m1.client_streaming assert not m1.server_streaming @@ -134,12 +138,8 @@ def test_service_unknown_request_type_fails_validation(): rpc SayHello (UnknownRequest) returns (Response); } """ - schema = parse_and_translate(source) - validator = SchemaValidator(schema) - assert not validator.validate() - assert any( - "Unknown type 'UnknownRequest'" in err.message for err in validator.errors - ) + with pytest.raises(FrontendError, match="Unknown type 'UnknownRequest'"): + parse_and_translate(source) def test_service_unknown_response_type_fails_validation(): @@ -153,7 +153,5 @@ def test_service_unknown_response_type_fails_validation(): rpc SayHello (Request) returns (UnknownReply); } """ - schema = parse_and_translate(source) - validator = 
SchemaValidator(schema) - assert not validator.validate() - assert any("Unknown type 'UnknownReply'" in err.message for err in validator.errors) + with pytest.raises(FrontendError, match="Unknown type 'UnknownReply'"): + parse_and_translate(source)