@@ -225,6 +225,29 @@ def to_python_value(self, ctx: FlyteContext, lv: Literal, expected_python_type:
225225 f"Conversion to python value expected type { expected_python_type } from literal not implemented"
226226 )
227227
def schema_match(self, schema: dict) -> bool:
    """Check if a JSON schema fragment matches this transformer's python_type.

    Matching is only attempted when ``self.python_type`` exposes a
    ``model_json_schema`` attribute and is not pydantic's ``BaseModel`` class
    itself. In that case the fragment's ``title``, ``type`` and set of
    ``required`` field names are compared against the type's own JSON schema.
    For every other type this returns False by default — override if needed.

    Args:
        schema: JSON schema fragment to test. Anything that is not a dict
            never matches.

    Returns:
        True only when title, type and required fields all agree. False in
        every other case, including when pydantic cannot be imported or the
        comparison itself raises.
    """
    import logging

    if not isinstance(schema, dict):
        return False
    try:
        from pydantic import BaseModel

        python_type = self.python_type
        if python_type is BaseModel or not hasattr(python_type, "model_json_schema"):
            return False
        own_schema = python_type.model_json_schema()  # type: ignore[attr-defined]
        return (
            schema.get("title") == own_schema.get("title")
            and schema.get("type") == own_schema.get("type")
            # Order of the required fields is irrelevant, so compare as sets.
            and set(schema.get("required", [])) == set(own_schema.get("required", []))
        )
    except ImportError:
        # pydantic could not be imported, so the comparison cannot be made.
        return False
    except Exception as exc:
        # Matching is best-effort: a type whose schema cannot be generated or
        # compared is treated as "no match", but the reason is kept visible
        # for debugging instead of being silently discarded.
        logging.getLogger(__name__).debug(
            "schema_match: could not compare schema for %s: %r", type(self).__name__, exc
        )
        return False
250+
228251 def from_binary_idl (self , binary_idl_object : Binary , expected_python_type : Type [T ]) -> Optional [T ]:
229252 """
230253 This function primarily handles deserialization for untyped dicts, dataclasses, Pydantic BaseModels, and attribute access.
@@ -1180,6 +1203,12 @@ def _handle_json_schema_property(
11801203 attr_union_type = reduce (lambda x , y : typing .Union [x , y ], attr_types )
11811204 return (property_key , attr_union_type ) # type: ignore
11821205
1206+ # Handle $ref (e.g. when anyOf recurses with a $ref variant like Optional[Inner])
1207+ # v1 iterates all properties (not just required), so $ref inside anyOf must be resolved here
1208+ if property_val .get ("$ref" ):
1209+ resolved_type = _get_element_type (property_val , schema )
1210+ return (property_key , resolved_type )
1211+
11831212 # Handle enum
11841213 if property_val .get ("enum" ):
11851214 property_type = "enum"
@@ -1203,6 +1232,9 @@ def _handle_json_schema_property(
12031232 elif property_val .get ("title" ):
12041233 # For nested dataclass
12051234 sub_schema_name = property_val ["title" ]
1235+ matched_type = _match_registered_type_from_schema (property_val )
1236+ if matched_type is not None :
1237+ return (property_key , typing .cast (GenericAlias , matched_type ))
12061238 return (
12071239 property_key ,
12081240 typing .cast (GenericAlias , convert_mashumaro_json_schema_to_python_class (property_val , sub_schema_name )),
@@ -1217,6 +1249,14 @@ def _handle_json_schema_property(
12171249 return (property_key , _get_element_type (property_val , schema )) # type: ignore
12181250
12191251
def _match_registered_type_from_schema(schema: dict) -> typing.Optional[type]:
    """Return the python_type of the first registered transformer matching ``schema``.

    Walks the values of ``TypeEngine._REGISTRY`` in iteration order and asks
    each transformer's ``schema_match`` whether it recognises the JSON schema
    fragment. Yields ``None`` when no transformer claims it.
    """
    registered = TypeEngine._REGISTRY.values()  # type: ignore[misc]
    candidates = (entry.python_type for entry in registered if entry.schema_match(schema))
    # The generator is lazy, so probing stops at the first transformer that matches.
    return next(candidates, None)
1258+
1259+
12201260def generate_attribute_list_from_dataclass_json_mixin (
12211261 schema : typing .Dict [str , typing .Any ],
12221262 schema_name : typing .Any ,
@@ -1233,6 +1273,15 @@ def generate_attribute_list_from_dataclass_json_mixin(
12331273 defs = schema .get ("$defs" , schema .get ("definitions" , {}))
12341274 if ref_name in defs :
12351275 ref_schema = defs [ref_name ].copy ()
1276+ # Check if the $ref points to an enum definition (no properties)
1277+ if ref_schema .get ("enum" ):
1278+ attribute_list .append ((property_key , str ))
1279+ continue
1280+ # Check if the $ref matches a registered custom type
1281+ matched_type = _match_registered_type_from_schema (ref_schema )
1282+ if matched_type is not None :
1283+ attribute_list .append ((property_key , typing .cast (GenericAlias , matched_type )))
1284+ continue
12361285 # Include $defs so nested models can resolve their own $refs
12371286 if "$defs" not in ref_schema and defs :
12381287 ref_schema ["$defs" ] = defs
@@ -2531,32 +2580,53 @@ def convert_mashumaro_json_schema_to_python_class(schema: dict, schema_name: typ
25312580def _get_element_type (
25322581 element_property : typing .Dict [str , typing .Any ], schema : typing .Optional [typing .Dict [str , typing .Any ]] = None
25332582) -> Type :
2534- # Handle $ref for nested models in arrays (e.g., List[NestedModel])
2535- # Pydantic generates JSON schema like: {'items': {'$ref': '#/$defs/NestedModel'}, 'type': 'array'}
2536- if element_property .get ("$ref" ):
2537- if schema is None :
2538- raise ValueError (f"Cannot resolve $ref '{ element_property ['$ref' ]} ' without schema context" )
2539- ref_path = element_property ["$ref" ]
2540- ref_name = ref_path .split ("/" )[- 1 ]
2583+ # Handle $ref for nested models and enums
2584+ # Ensure that the element is actually a $ref and we have the entire schema to look up
2585+ if element_property .get ("$ref" ) and schema is not None :
2586+ ref_name = element_property ["$ref" ].split ("/" )[- 1 ]
25412587 defs = schema .get ("$defs" , schema .get ("definitions" , {}))
2588+ # Look up for ref_name in the defs defined in the schema
25422589 if ref_name in defs :
2590+ # Don't mutate the original schema
25432591 ref_schema = defs [ref_name ].copy ()
2544- # Include $defs so nested models can resolve their own $refs
2592+ # Guard the nested enum elements inside containers
2593+ if ref_schema .get ("enum" ):
2594+ return str
2595+ # Check if the $ref matches a registered custom type
2596+ if (matched_type := _match_registered_type_from_schema (ref_schema )) is not None :
2597+ return matched_type
2598+ # if defs not in the schema, they need to be propagated into the resolved schema
25452599 if "$defs" not in ref_schema and defs :
25462600 ref_schema ["$defs" ] = defs
2601+ # build a dataclass from the resolved schema
25472602 return convert_mashumaro_json_schema_to_python_class (ref_schema , ref_name )
2548- else :
2549- raise ValueError (f"Cannot find definition for $ref '{ ref_path } ' in schema" )
2550-
2551- element_type = (
2552- [e_property ["type" ] for e_property in element_property ["anyOf" ]] # type: ignore
2553- if element_property .get ("anyOf" )
2554- else element_property ["type" ]
2555- )
2556- element_format = element_property ["format" ] if "format" in element_property else None
2603+ # default to str on failure. Shouldn't happen with valid pydantic schemas
2604+ return str
25572605
2606+ # Handle anyOf (e.g. Optional[int], Optional[Inner])
2607+ # Early return block replacing the previous list comprehension which would fail when an anyOf reference was a $ref
2608+ # (meaning the variant has no "type" key).
2609+ if element_property .get ("anyOf" ):
2610+ # Separate the non-null variants. Note: for a $ref variant, v.get("type") is None (the key is absent), not "null",
2611+ # so it is kept; only a {"type": "null"} variant is filtered out.
2612+ variants = element_property ["anyOf" ]
2613+ non_null = [v for v in variants if v .get ("type" ) != "null" ]
2614+ # Detect if this is an Optional pattern here
2615+ has_null = len (non_null ) < len (variants )
2616+ # This recurses on the first non-null variant which would handle the $ref, nested_arrays, nested_objects...
2617+ # anything. Wrap it in Optional if has_null.
2618+ if non_null :
2619+ inner_type = _get_element_type (non_null [0 ], schema )
2620+ return typing .Optional [inner_type ] if has_null else inner_type # type: ignore
2621+ # Every variant is {"type": "null"}, so the element type is NoneType
2622+ return type (None )
2623+
2624+ element_type = element_property .get ("type" , "string" )
2625+ element_format = element_property .get ("format" )
2626+
2627+ # Handle marshmallow-style Optional types where "type" is a list e.g. ["integer", "null"]
2628+ # This pattern is not used by Pydantic (which uses anyOf) but is used by the marshmallow/DataClassJsonMixin path in v1
25582629 if isinstance (element_type , list ):
2559- # Element type of Optional[int] is [integer, None]
25602630 return typing .Optional [_get_element_type ({"type" : element_type [0 ]}, schema )] # type: ignore
25612631
25622632 if element_type == "string" :
@@ -2570,6 +2640,16 @@ def _get_element_type(
25702640 return int
25712641 else :
25722642 return float
2643+ # Recursively discover the types when an array or object element type is discovered
2644+ elif element_type == "array" :
2645+ return typing .List [_get_element_type (element_property .get ("items" , {}), schema )] # type: ignore
2646+ elif element_type == "object" :
2647+ if element_property .get ("additionalProperties" ):
2648+ return typing .Dict [str , _get_element_type (element_property ["additionalProperties" ], schema )] # type: ignore
2649+ return dict
2650+ # Corner case - practically useless but List[None] is a legal Python type
2651+ elif element_type == "null" :
2652+ return type (None )
25732653 return str
25742654
25752655
0 commit comments