Skip to content

Commit 3cd7925

Browse files
committed
fix compiler bugs
1 parent 46b53e6 commit 3cd7925

15 files changed

Lines changed: 422 additions & 96 deletions

File tree

compiler/README.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -345,7 +345,7 @@ Generates structs with:
345345

346346
```go
347347
type Cat struct {
348-
Friend *Dog `fory:"trackRef"`
348+
Friend *Dog `fory:"ref"`
349349
Name *string `fory:"nullable"`
350350
Tags []string
351351
}
@@ -362,7 +362,7 @@ Generates structs with:
362362
```rust
363363
#[derive(ForyObject, Debug, Clone, PartialEq, Default)]
364364
pub struct Cat {
365-
pub friend: Rc<Dog>,
365+
pub friend: Arc<Dog>,
366366
#[fory(nullable = true)]
367367
pub name: Option<String>,
368368
pub tags: Vec<String>,

compiler/fory_compiler/generators/cpp.py

Lines changed: 52 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -165,7 +165,14 @@ def generate_header(self) -> GeneratedFile:
165165
def collect_message_includes(self, message: Message, includes: Set[str]):
166166
"""Collect includes for a message and its nested types recursively."""
167167
for field in message.fields:
168-
self.collect_includes(field.field_type, field.optional, field.ref, includes)
168+
self.collect_includes(
169+
field.field_type,
170+
field.optional,
171+
field.ref,
172+
includes,
173+
field.element_optional,
174+
field.element_ref,
175+
)
169176
for nested_msg in message.nested_messages:
170177
self.collect_message_includes(nested_msg, includes)
171178

@@ -236,6 +243,8 @@ def generate_message(
236243
field.field_type,
237244
field.optional,
238245
field.ref,
246+
field.element_optional,
247+
field.element_ref,
239248
lineage,
240249
)
241250
field_name = self.to_snake_case(field.name)
@@ -258,8 +267,11 @@ def generate_message(
258267
lines.append("};")
259268

260269
# FORY_STRUCT macro (must stay in type namespace for ADL)
261-
field_names = ", ".join(self.to_snake_case(f.name) for f in message.fields)
262-
lines.append(f"FORY_STRUCT({type_name}, {field_names});")
270+
if message.fields:
271+
field_names = ", ".join(self.to_snake_case(f.name) for f in message.fields)
272+
lines.append(f"FORY_STRUCT({type_name}, {field_names});")
273+
else:
274+
lines.append(f"FORY_STRUCT({type_name});")
263275

264276
return lines
265277

@@ -309,6 +321,8 @@ def generate_type(
309321
field_type: FieldType,
310322
nullable: bool = False,
311323
ref: bool = False,
324+
element_optional: bool = False,
325+
element_ref: bool = False,
312326
parent_stack: Optional[List[Message]] = None,
313327
) -> str:
314328
"""Generate C++ type string."""
@@ -321,25 +335,40 @@ def generate_type(
321335
elif isinstance(field_type, NamedType):
322336
type_name = self.resolve_nested_type_name(field_type.name, parent_stack)
323337
if ref:
324-
return f"std::shared_ptr<{type_name}>"
338+
type_name = f"std::shared_ptr<{type_name}>"
325339
if nullable:
326-
return f"std::optional<{type_name}>"
340+
type_name = f"std::optional<{type_name}>"
327341
return type_name
328342

329343
elif isinstance(field_type, ListType):
330344
element_type = self.generate_type(
331-
field_type.element_type, False, False, parent_stack
345+
field_type.element_type,
346+
element_optional,
347+
element_ref,
348+
False,
349+
False,
350+
parent_stack,
332351
)
333-
return f"std::vector<{element_type}>"
352+
list_type = f"std::vector<{element_type}>"
353+
if ref:
354+
list_type = f"std::shared_ptr<{list_type}>"
355+
if nullable:
356+
list_type = f"std::optional<{list_type}>"
357+
return list_type
334358

335359
elif isinstance(field_type, MapType):
336360
key_type = self.generate_type(
337-
field_type.key_type, False, False, parent_stack
361+
field_type.key_type, False, False, False, False, parent_stack
338362
)
339363
value_type = self.generate_type(
340-
field_type.value_type, False, False, parent_stack
364+
field_type.value_type, False, False, False, False, parent_stack
341365
)
342-
return f"std::map<{key_type}, {value_type}>"
366+
map_type = f"std::map<{key_type}, {value_type}>"
367+
if ref:
368+
map_type = f"std::shared_ptr<{map_type}>"
369+
if nullable:
370+
map_type = f"std::optional<{map_type}>"
371+
return map_type
343372

344373
return "void*"
345374

@@ -363,7 +392,13 @@ def resolve_nested_type_name(
363392
return type_name
364393

365394
def collect_includes(
366-
self, field_type: FieldType, nullable: bool, ref: bool, includes: Set[str]
395+
self,
396+
field_type: FieldType,
397+
nullable: bool,
398+
ref: bool,
399+
includes: Set[str],
400+
element_optional: bool = False,
401+
element_ref: bool = False,
367402
):
368403
"""Collect required includes for a field type."""
369404
if nullable:
@@ -381,7 +416,12 @@ def collect_includes(
381416

382417
elif isinstance(field_type, ListType):
383418
includes.add("<vector>")
384-
self.collect_includes(field_type.element_type, False, False, includes)
419+
self.collect_includes(
420+
field_type.element_type,
421+
element_optional,
422+
element_ref,
423+
includes,
424+
)
385425

386426
elif isinstance(field_type, MapType):
387427
includes.add("<map>")

compiler/fory_compiler/generators/go.py

Lines changed: 43 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -244,6 +244,8 @@ def generate_field(
244244
field.field_type,
245245
field.optional,
246246
field.ref,
247+
field.element_optional,
248+
field.element_ref,
247249
parent_stack,
248250
)
249251
field_name = self.to_pascal_case(
@@ -252,10 +254,29 @@ def generate_field(
252254

253255
# Build fory tag
254256
tags = []
257+
is_list = isinstance(field.field_type, ListType)
258+
nullable_tag: Optional[bool] = None
259+
ref_tag: Optional[bool] = None
260+
255261
if field.optional:
256-
tags.append("nullable")
262+
nullable_tag = True
263+
elif is_list and (field.ref or field.element_optional or field.element_ref):
264+
nullable_tag = False
265+
257266
if field.ref:
258-
tags.append("trackRef")
267+
ref_tag = True
268+
elif is_list and field.element_ref:
269+
ref_tag = False
270+
271+
if nullable_tag is True:
272+
tags.append("nullable")
273+
elif nullable_tag is False:
274+
tags.append("nullable=false")
275+
276+
if ref_tag is True:
277+
tags.append("ref")
278+
elif ref_tag is False:
279+
tags.append("ref=false")
259280

260281
if tags:
261282
tag_str = ",".join(tags)
@@ -270,6 +291,8 @@ def generate_type(
270291
field_type: FieldType,
271292
nullable: bool = False,
272293
ref: bool = False,
294+
element_optional: bool = False,
295+
element_ref: bool = False,
273296
parent_stack: Optional[List[Message]] = None,
274297
) -> str:
275298
"""Generate Go type string."""
@@ -287,16 +310,31 @@ def generate_type(
287310

288311
elif isinstance(field_type, ListType):
289312
element_type = self.generate_type(
290-
field_type.element_type, False, False, parent_stack
313+
field_type.element_type,
314+
element_optional,
315+
element_ref,
316+
False,
317+
False,
318+
parent_stack,
291319
)
292320
return f"[]{element_type}"
293321

294322
elif isinstance(field_type, MapType):
295323
key_type = self.generate_type(
296-
field_type.key_type, False, False, parent_stack
324+
field_type.key_type,
325+
False,
326+
False,
327+
False,
328+
False,
329+
parent_stack,
297330
)
298331
value_type = self.generate_type(
299-
field_type.value_type, False, False, parent_stack
332+
field_type.value_type,
333+
False,
334+
False,
335+
False,
336+
False,
337+
parent_stack,
300338
)
301339
return f"map[{key_type}]{value_type}"
302340

compiler/fory_compiler/generators/java.py

Lines changed: 56 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -322,7 +322,7 @@ def generate_outer_class_file(self, outer_classname: str) -> GeneratedFile:
322322
def collect_message_imports(self, message: Message, imports: Set[str]):
323323
"""Collect imports for a message and all its nested types recursively."""
324324
for field in message.fields:
325-
self.collect_imports(field.field_type, imports)
325+
self.collect_field_imports(field, imports)
326326
if field.optional or field.ref:
327327
imports.add("org.apache.fory.annotation.ForyField")
328328

@@ -422,7 +422,12 @@ def generate_field(self, field: Field) -> List[str]:
422422
lines.append(f"@ForyField({', '.join(annotations)})")
423423

424424
# Field type
425-
java_type = self.generate_type(field.field_type, field.optional)
425+
java_type = self.generate_type(
426+
field.field_type,
427+
field.optional,
428+
field.element_optional,
429+
field.element_ref,
430+
)
426431

427432
lines.append(f"private {java_type} {self.to_camel_case(field.name)};")
428433
lines.append("")
@@ -432,7 +437,12 @@ def generate_field(self, field: Field) -> List[str]:
432437
def generate_getter_setter(self, field: Field) -> List[str]:
433438
"""Generate getter and setter for a field."""
434439
lines = []
435-
java_type = self.generate_type(field.field_type, field.optional)
440+
java_type = self.generate_type(
441+
field.field_type,
442+
field.optional,
443+
field.element_optional,
444+
field.element_ref,
445+
)
436446
field_name = self.to_camel_case(field.name)
437447
pascal_name = self.to_pascal_case(field.name)
438448

@@ -450,7 +460,13 @@ def generate_getter_setter(self, field: Field) -> List[str]:
450460

451461
return lines
452462

453-
def generate_type(self, field_type: FieldType, nullable: bool = False) -> str:
463+
def generate_type(
464+
self,
465+
field_type: FieldType,
466+
nullable: bool = False,
467+
element_optional: bool = False,
468+
element_ref: bool = False,
469+
) -> str:
454470
"""Generate Java type string."""
455471
if isinstance(field_type, PrimitiveType):
456472
if nullable and field_type.kind in self.BOXED_MAP:
@@ -463,7 +479,11 @@ def generate_type(self, field_type: FieldType, nullable: bool = False) -> str:
463479
elif isinstance(field_type, ListType):
464480
# Use primitive arrays for numeric types
465481
if isinstance(field_type.element_type, PrimitiveType):
466-
if field_type.element_type.kind in self.PRIMITIVE_ARRAY_MAP:
482+
if (
483+
field_type.element_type.kind in self.PRIMITIVE_ARRAY_MAP
484+
and not element_optional
485+
and not element_ref
486+
):
467487
return self.PRIMITIVE_ARRAY_MAP[field_type.element_type.kind]
468488
element_type = self.generate_type(field_type.element_type, True)
469489
return f"List<{element_type}>"
@@ -475,7 +495,13 @@ def generate_type(self, field_type: FieldType, nullable: bool = False) -> str:
475495

476496
return "Object"
477497

478-
def collect_imports(self, field_type: FieldType, imports: Set[str]):
498+
def collect_type_imports(
499+
self,
500+
field_type: FieldType,
501+
imports: Set[str],
502+
element_optional: bool = False,
503+
element_ref: bool = False,
504+
):
479505
"""Collect required imports for a field type."""
480506
if isinstance(field_type, PrimitiveType):
481507
if field_type.kind == PrimitiveKind.DATE:
@@ -486,26 +512,37 @@ def collect_imports(self, field_type: FieldType, imports: Set[str]):
486512
elif isinstance(field_type, ListType):
487513
# Primitive arrays don't need List import
488514
if isinstance(field_type.element_type, PrimitiveType):
489-
if field_type.element_type.kind in self.PRIMITIVE_ARRAY_MAP:
515+
if (
516+
field_type.element_type.kind in self.PRIMITIVE_ARRAY_MAP
517+
and not element_optional
518+
and not element_ref
519+
):
490520
return # No import needed for primitive arrays
491521
imports.add("java.util.List")
492-
self.collect_imports(field_type.element_type, imports)
522+
self.collect_type_imports(field_type.element_type, imports)
493523

494524
elif isinstance(field_type, MapType):
495525
imports.add("java.util.Map")
496-
self.collect_imports(field_type.key_type, imports)
497-
self.collect_imports(field_type.value_type, imports)
526+
self.collect_type_imports(field_type.key_type, imports)
527+
self.collect_type_imports(field_type.value_type, imports)
528+
529+
def collect_field_imports(self, field: Field, imports: Set[str]):
530+
"""Collect imports for a field, including list modifiers."""
531+
self.collect_type_imports(
532+
field.field_type,
533+
imports,
534+
field.element_optional,
535+
field.element_ref,
536+
)
498537

499538
def has_array_field(self, message: Message) -> bool:
500539
"""Check if message has any array fields (byte[] or primitive arrays)."""
501540
for field in message.fields:
502541
if isinstance(field.field_type, PrimitiveType):
503542
if field.field_type.kind == PrimitiveKind.BYTES:
504543
return True
505-
elif isinstance(field.field_type, ListType):
506-
if isinstance(field.field_type.element_type, PrimitiveType):
507-
if field.field_type.element_type.kind in self.PRIMITIVE_ARRAY_MAP:
508-
return True
544+
elif self.is_primitive_array_field(field):
545+
return True
509546
return False
510547

511548
def is_primitive_array_field(self, field: Field) -> bool:
@@ -514,7 +551,11 @@ def is_primitive_array_field(self, field: Field) -> bool:
514551
return field.field_type.kind == PrimitiveKind.BYTES
515552
if isinstance(field.field_type, ListType):
516553
if isinstance(field.field_type.element_type, PrimitiveType):
517-
return field.field_type.element_type.kind in self.PRIMITIVE_ARRAY_MAP
554+
return (
555+
field.field_type.element_type.kind in self.PRIMITIVE_ARRAY_MAP
556+
and not field.element_optional
557+
and not field.element_ref
558+
)
518559
return False
519560

520561
def generate_equals_method(self, message: Message) -> List[str]:

0 commit comments

Comments
 (0)