diff --git a/datafusion/expr/src/type_coercion/functions.rs b/datafusion/expr/src/type_coercion/functions.rs index d5cb98a46ef43..99d4090a59a74 100644 --- a/datafusion/expr/src/type_coercion/functions.rs +++ b/datafusion/expr/src/type_coercion/functions.rs @@ -130,7 +130,7 @@ pub fn fields_with_udf( let valid_types = get_valid_types_with_udf(type_signature, &current_types, func)?; if valid_types .iter() - .any(|data_type| data_type == &current_types) + .any(|data_type| data_types_match(data_type, &current_types)) { return Ok(current_fields.to_vec()); } @@ -236,7 +236,7 @@ pub fn data_types( get_valid_types(function_name.as_ref(), type_signature, current_types)?; if valid_types .iter() - .any(|data_type| data_type == current_types) + .any(|data_type| data_types_match(data_type, current_types)) { return Ok(current_types.to_vec()); } @@ -307,6 +307,51 @@ fn try_coerce_types( ) } +fn data_types_match(valid_types: &[DataType], current_types: &[DataType]) -> bool { + fn field_matches(valid: &FieldRef, current: &FieldRef, compare_name: bool) -> bool { + (!compare_name || valid.name() == current.name()) + && valid.is_nullable() == current.is_nullable() + && data_type_matches(valid.data_type(), current.data_type()) + } + + // Allow nested types that differ only by wrapper field names chosen by an + // Arrow producer (for example parquet lists/maps), but keep struct field + // names and ordering exact because some kernels depend on them at runtime. 
+ fn data_type_matches(valid: &DataType, current: &DataType) -> bool { + match (valid, current) { + (valid, current) if valid == current => true, + (DataType::List(valid), DataType::List(current)) + | (DataType::LargeList(valid), DataType::LargeList(current)) + | (DataType::ListView(valid), DataType::ListView(current)) + | (DataType::LargeListView(valid), DataType::LargeListView(current)) => { + field_matches(valid, current, false) + } + ( + DataType::FixedSizeList(valid, valid_size), + DataType::FixedSizeList(current, current_size), + ) => valid_size == current_size && field_matches(valid, current, false), + ( + DataType::Map(valid, valid_sorted), + DataType::Map(current, current_sorted), + ) => valid_sorted == current_sorted && field_matches(valid, current, false), + (DataType::Struct(valid), DataType::Struct(current)) => { + valid.len() == current.len() + && valid + .iter() + .zip(current.iter()) + .all(|(valid, current)| field_matches(valid, current, true)) + } + _ => false, + } + } + + valid_types.len() == current_types.len() + && valid_types + .iter() + .zip(current_types) + .all(|(valid_type, current_type)| data_type_matches(valid_type, current_type)) +} + fn get_valid_types_with_udf( signature: &TypeSignature, current_types: &[DataType], @@ -757,6 +802,10 @@ fn maybe_data_types( for (i, valid_type) in valid_types.iter().enumerate() { let current_type = &current_types[i]; + // Keep exact equality here. Some kernels such as `make_array` + // require nested field names/order to match exactly at runtime. + // Structural-equivalence short-circuiting is handled earlier by + // `data_types_match`. if current_type == valid_type { new_type.push(current_type.clone()) } else { @@ -789,6 +838,10 @@ fn maybe_data_types_without_coercion( for (i, valid_type) in valid_types.iter().enumerate() { let current_type = &current_types[i]; + // Keep exact equality here. Some kernels such as `make_array` + // require nested field names/order to match exactly at runtime. 
+ // Structural-equivalence short-circuiting is handled earlier by + // `data_types_match`. if current_type == valid_type { new_type.push(current_type.clone()) } else if can_cast_types(current_type, valid_type) { @@ -1044,6 +1097,191 @@ mod tests { } } + #[test] + fn test_maybe_data_types_uses_exact_nested_types() { + let struct_fields = vec![ + Field::new("id", DataType::Utf8, true), + Field::new("prim", DataType::Boolean, true), + ]; + let current_type = DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(struct_fields.clone().into()), + true, + ))); + let valid_type = DataType::List(Arc::new(Field::new( + "element", + DataType::Struct(struct_fields.into()), + true, + ))); + + assert!(current_type.equals_datatype(&valid_type)); + assert_ne!(current_type, valid_type); + assert_eq!( + maybe_data_types(std::slice::from_ref(&valid_type), &[current_type]), + Some(vec![valid_type]) + ); + } + + #[test] + fn test_maybe_data_types_without_coercion_uses_exact_nested_types() { + let valid_type = DataType::Struct( + vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", DataType::Int64, true), + ] + .into(), + ); + let current_type = DataType::Struct( + vec![ + Field::new("b", DataType::Int64, true), + Field::new("a", DataType::Int64, true), + ] + .into(), + ); + + assert!(current_type.equals_datatype(&valid_type)); + assert_ne!(current_type, valid_type); + assert_eq!( + maybe_data_types_without_coercion( + std::slice::from_ref(&valid_type), + &[current_type], + ), + Some(vec![valid_type]) + ); + } + + #[test] + fn test_data_types_match_ignores_list_field_name() { + let struct_fields = vec![ + Field::new("id", DataType::Utf8, true), + Field::new("prim", DataType::Boolean, true), + ]; + let current_type = DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(struct_fields.clone().into()), + true, + ))); + let valid_type = DataType::List(Arc::new(Field::new( + "element", + DataType::Struct(struct_fields.into()), + true, + ))); + + 
assert!(data_types_match(&[valid_type], &[current_type])); + } + + #[test] + fn test_data_types_match_ignores_list_view_field_name() { + let struct_fields = vec![ + Field::new("id", DataType::Utf8, true), + Field::new("prim", DataType::Boolean, true), + ]; + let current_type = DataType::ListView(Arc::new(Field::new( + "item", + DataType::Struct(struct_fields.clone().into()), + true, + ))); + let valid_type = DataType::ListView(Arc::new(Field::new( + "element", + DataType::Struct(struct_fields.into()), + true, + ))); + + assert!(data_types_match(&[valid_type], &[current_type])); + } + + #[test] + fn test_data_types_match_recurses_through_struct_fields() { + let current_type = DataType::Struct( + vec![Field::new( + "a", + DataType::List(Arc::new(Field::new("item", DataType::Int64, true))), + true, + )] + .into(), + ); + let valid_type = DataType::Struct( + vec![Field::new( + "a", + DataType::List(Arc::new(Field::new("element", DataType::Int64, true))), + true, + )] + .into(), + ); + + assert!(data_types_match(&[valid_type], &[current_type])); + } + + #[test] + fn test_data_types_match_ignores_map_field_name() { + let current_type = DataType::Map( + Arc::new(Field::new( + "entries", + DataType::Struct( + vec![ + Field::new("key", DataType::Utf8, false), + Field::new( + "value", + DataType::List(Arc::new(Field::new( + "item", + DataType::Int64, + true, + ))), + true, + ), + ] + .into(), + ), + false, + )), + false, + ); + let valid_type = DataType::Map( + Arc::new(Field::new( + "pairs", + DataType::Struct( + vec![ + Field::new("key", DataType::Utf8, false), + Field::new( + "value", + DataType::List(Arc::new(Field::new( + "element", + DataType::Int64, + true, + ))), + true, + ), + ] + .into(), + ), + false, + )), + false, + ); + + assert!(data_types_match(&[valid_type], &[current_type])); + } + + #[test] + fn test_data_types_match_respects_struct_field_order() { + let valid_type = DataType::Struct( + vec![ + Field::new("a", DataType::Int64, true), + Field::new("b", 
DataType::Int64, true), + ] + .into(), + ); + let current_type = DataType::Struct( + vec![ + Field::new("b", DataType::Int64, true), + Field::new("a", DataType::Int64, true), + ] + .into(), + ); + + assert!(!data_types_match(&[valid_type], &[current_type])); + } + #[test] fn test_get_valid_types_numeric() -> Result<()> { let get_valid_types_flatten = @@ -1223,6 +1461,36 @@ mod tests { Ok(()) } + #[test] + fn test_fields_with_udf_preserves_equivalent_nested_types() -> Result<()> { + let struct_fields = vec![ + Field::new("id", DataType::Utf8, true), + Field::new("prim", DataType::Boolean, true), + ]; + let current_type = DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(struct_fields.clone().into()), + true, + ))); + let signature_type = DataType::List(Arc::new(Field::new( + "element", + DataType::Struct(struct_fields.into()), + true, + ))); + + assert!(current_type.equals_datatype(&signature_type)); + + let current_fields = vec![Arc::new(Field::new("field", current_type, true))]; + let coerced_fields = fields_with_udf( + &current_fields, + &MockUdf(Signature::exact(vec![signature_type], Volatility::Stable)), + )?; + + assert_eq!(coerced_fields, current_fields); + + Ok(()) + } + #[test] fn test_nested_wildcard_fixed_size_lists() -> Result<()> { let type_into = DataType::FixedSizeList( diff --git a/datafusion/sqllogictest/test_files/spark/array/array.slt b/datafusion/sqllogictest/test_files/spark/array/array.slt index 79dca1c10a7d0..3c07b2326adf7 100644 --- a/datafusion/sqllogictest/test_files/spark/array/array.slt +++ b/datafusion/sqllogictest/test_files/spark/array/array.slt @@ -85,3 +85,43 @@ query ? 
SELECT array(arrow_cast(array(1,2), 'LargeList(Int64)'), array(3)); ---- [[1, 2], [3]] + +query TT +EXPLAIN SELECT array_element(arrow_cast([1, 2], 'LargeList(Int64)'), 1); +---- +logical_plan +01)Projection: Int64(1) AS arrow_cast(make_array(Int64(1),Int64(2)),Utf8("LargeList(Int64)"))[Int64(1)] +02)--EmptyRelation: rows=1 +physical_plan +01)ProjectionExec: expr=[1 as arrow_cast(make_array(Int64(1),Int64(2)),Utf8("LargeList(Int64)"))[Int64(1)]] +02)--PlaceholderRowExec + +query TT +EXPLAIN SELECT get_field(array_element(MAP {'ECID': [{id: 1, prim: true}, {id: 2, prim: false}]}['ECID'], 1), 'id'); +---- +logical_plan +01)Projection: Int64(1) AS map(make_array(Utf8("ECID")),make_array(make_array(named_struct(Utf8("id"),Int64(1),Utf8("prim"),Boolean(true)),named_struct(Utf8("id"),Int64(2),Utf8("prim"),Boolean(false)))))[ECID][Int64(1)][id] +02)--EmptyRelation: rows=1 +physical_plan +01)ProjectionExec: expr=[1 as map(make_array(Utf8("ECID")),make_array(make_array(named_struct(Utf8("id"),Int64(1),Utf8("prim"),Boolean(true)),named_struct(Utf8("id"),Int64(2),Utf8("prim"),Boolean(false)))))[ECID][Int64(1)][id]] +02)--PlaceholderRowExec + +query TT +EXPLAIN SELECT array_element(get_field({'a': [1, 2]}, 'a'), 1); +---- +logical_plan +01)Projection: Int64(1) AS named_struct(Utf8("a"),make_array(Int64(1),Int64(2)))[a][Int64(1)] +02)--EmptyRelation: rows=1 +physical_plan +01)ProjectionExec: expr=[1 as named_struct(Utf8("a"),make_array(Int64(1),Int64(2)))[a][Int64(1)]] +02)--PlaceholderRowExec + +query TT +EXPLAIN SELECT array_element(arrow_cast(make_array(1, 2), 'FixedSizeList(2, Int64)'), 1); +---- +logical_plan +01)Projection: Int64(1) AS arrow_cast(make_array(Int64(1),Int64(2)),Utf8("FixedSizeList(2, Int64)"))[Int64(1)] +02)--EmptyRelation: rows=1 +physical_plan +01)ProjectionExec: expr=[1 as arrow_cast(make_array(Int64(1),Int64(2)),Utf8("FixedSizeList(2, Int64)"))[Int64(1)]] +02)--PlaceholderRowExec