diff --git a/datafusion/common/src/nested_struct.rs b/datafusion/common/src/nested_struct.rs index bf2558f313069..cdd6215d08e2f 100644 --- a/datafusion/common/src/nested_struct.rs +++ b/datafusion/common/src/nested_struct.rs @@ -17,8 +17,11 @@ use crate::error::{_plan_err, Result}; use arrow::{ - array::{Array, ArrayRef, StructArray, new_null_array}, - compute::{CastOptions, cast_with_options}, + array::{ + Array, ArrayRef, DictionaryArray, GenericListArray, GenericListViewArray, + StructArray, downcast_integer, new_null_array, + }, + compute::{CastOptions, can_cast_types, cast_with_options}, datatypes::{DataType, DataType::Struct, Field, FieldRef}, }; use std::{collections::HashSet, sync::Arc}; @@ -80,14 +83,17 @@ fn cast_struct_column( match source_child_opt { Some(source_child_col) => { - let adapted_child = - cast_column(source_child_col, target_child_field, cast_options) - .map_err(|e| { - e.context(format!( - "While casting struct field '{}'", - target_child_field.name() - )) - })?; + let adapted_child = cast_column( + source_child_col, + target_child_field.data_type(), + cast_options, + ) + .map_err(|e| { + e.context(format!( + "While casting struct field '{}'", + target_child_field.name() + )) + })?; arrays.push(adapted_child); } None => { @@ -127,18 +133,17 @@ fn cast_struct_column( /// ``` /// use arrow::array::{ArrayRef, Int64Array}; /// use arrow::compute::CastOptions; -/// use arrow::datatypes::{DataType, Field}; +/// use arrow::datatypes::DataType; /// use datafusion_common::nested_struct::cast_column; /// use std::sync::Arc; /// /// let source: ArrayRef = Arc::new(Int64Array::from(vec![1, i64::MAX])); -/// let target = Field::new("ints", DataType::Int32, true); /// // Permit lossy conversions by producing NULL on overflow instead of erroring /// let options = CastOptions { /// safe: true, /// ..Default::default() /// }; -/// let result = cast_column(&source, &target, &options).unwrap(); +/// let result = cast_column(&source, &DataType::Int32, &options).unwrap(); /// assert!(result.is_null(1)); /// ``` /// @@ -151,7 +156,7 @@ fn cast_struct_column( /// /// # Arguments /// * `source_col` - The source array to cast -/// * `target_field` - The target field definition (including type and metadata) +/// * `target_type` - The target data type to cast to /// * `cast_options` - Options that govern strictness and formatting of the cast /// /// # Returns @@ -165,18 +170,139 @@ fn cast_struct_column( /// - Invalid data type combinations are encountered pub fn cast_column( source_col: &ArrayRef, - target_field: &Field, + target_type: &DataType, cast_options: &CastOptions, ) -> Result { - match target_field.data_type() { - Struct(target_fields) => { + match (source_col.data_type(), target_type) { + (_, Struct(target_fields)) => { cast_struct_column(source_col, target_fields, cast_options) } - _ => Ok(cast_with_options( + (DataType::List(_), DataType::List(target_inner)) => { + cast_list_column::(source_col, target_inner, cast_options) + } + (DataType::LargeList(_), DataType::LargeList(target_inner)) => { + cast_list_column::(source_col, target_inner, cast_options) + } + (DataType::ListView(_), DataType::ListView(target_inner)) => { + cast_list_view_column::(source_col, target_inner, cast_options) + } + (DataType::LargeListView(_), DataType::LargeListView(target_inner)) => { + cast_list_view_column::(source_col, target_inner, cast_options) + } + ( + DataType::Dictionary(source_key_type, _), + DataType::Dictionary(target_key_type, target_value_type), + ) => cast_dictionary_column( source_col, - target_field.data_type(), + source_key_type, + target_key_type, + target_value_type, cast_options, - )?), + ), + _ => Ok(cast_with_options(source_col, target_type, cast_options)?), + } +} + +fn cast_list_column( + source_col: &ArrayRef, + target_inner_field: &FieldRef, + cast_options: &CastOptions, +) -> Result { + let source_list = source_col + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + crate::error::DataFusionError::Plan(format!( + "Expected list array but got {}", + source_col.data_type() + )) + })?; + + let cast_values = cast_column( + source_list.values(), + target_inner_field.data_type(), + cast_options, + )?; + + let result = GenericListArray::::new( + Arc::clone(target_inner_field), + source_list.offsets().clone(), + cast_values, + source_list.nulls().cloned(), + ); + Ok(Arc::new(result)) +} + +fn cast_list_view_column( + source_col: &ArrayRef, + target_inner_field: &FieldRef, + cast_options: &CastOptions, +) -> Result { + let source_list = source_col + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + crate::error::DataFusionError::Plan(format!( + "Expected list view array but got {}", + source_col.data_type() + )) + })?; + + let cast_values = cast_column( + source_list.values(), + target_inner_field.data_type(), + cast_options, + )?; + + let result = GenericListViewArray::::try_new( + Arc::clone(target_inner_field), + source_list.offsets().clone(), + source_list.sizes().clone(), + cast_values, + source_list.nulls().cloned(), + )?; + Ok(Arc::new(result)) +} + +fn cast_dictionary_column( + source_col: &ArrayRef, + source_key_type: &DataType, + target_key_type: &DataType, + target_value_type: &DataType, + cast_options: &CastOptions, +) -> Result { + // Dispatch on source key type to access keys/values, then recursively + // cast values. Rebuild with the source key type first. + macro_rules! cast_dict_values { + ($t:ty) => {{ + let source_dict = source_col + .as_any() + .downcast_ref::>() + .expect("downcast must succeed"); + let cast_values = + cast_column(source_dict.values(), target_value_type, cast_options)?; + Ok(Arc::new(DictionaryArray::<$t>::new( + source_dict.keys().clone(), + cast_values, + )) as ArrayRef) + }}; + } + + let result: Result = downcast_integer! { + source_key_type => (cast_dict_values), + k => _plan_err!("Unsupported dictionary key type: {k}") + }; + let result = result?; + + // If key types differ, delegate key casting to Arrow. + if source_key_type != target_key_type { + let target_dict_type = DataType::Dictionary( + Box::new(target_key_type.clone()), + Box::new(target_value_type.clone()), + ); + Ok(cast_with_options(&result, &target_dict_type, cast_options)?) + } else { + Ok(result) } } @@ -281,31 +407,84 @@ fn validate_field_compatibility( ); } - // Check if the matching field types are compatible - match (source_field.data_type(), target_field.data_type()) { - // Recursively validate nested structs + validate_data_type_compatibility( + target_field.name(), + source_field.data_type(), + target_field.data_type(), + ) +} + +/// Validates that `source_type` can be cast to `target_type`, recursively +/// handling container types that wrap structs. +pub fn validate_data_type_compatibility( + field_name: &str, + source_type: &DataType, + target_type: &DataType, +) -> Result<()> { + match (source_type, target_type) { (Struct(source_nested), Struct(target_nested)) => { validate_struct_compatibility(source_nested, target_nested)?; } - // For non-struct types, use the existing castability check + (DataType::List(s), DataType::List(t)) + | (DataType::LargeList(s), DataType::LargeList(t)) + | (DataType::ListView(s), DataType::ListView(t)) + | (DataType::LargeListView(s), DataType::LargeListView(t)) => { + validate_field_compatibility(s, t)?; + } + (DataType::Dictionary(s_key, s_val), DataType::Dictionary(t_key, t_val)) => { + if !can_cast_types(s_key, t_key) { + return _plan_err!( + "Cannot cast dictionary key type {} to {} for field '{}'", + s_key, + t_key, + field_name + ); + } + validate_data_type_compatibility(field_name, s_val, t_val)?; + } _ => { - if !arrow::compute::can_cast_types( - source_field.data_type(), - target_field.data_type(), - ) { + if !can_cast_types(source_type, target_type) { return _plan_err!( "Cannot cast struct field '{}' from type {} to type {}", - target_field.name(), - source_field.data_type(), - target_field.data_type() + field_name, + source_type, + target_type ); } } } - Ok(()) } +/// Returns true if casting from `source_type` to `target_type` requires +/// name-based nested struct casting logic, rather than Arrow's standard cast. +/// +/// This is the case when both types are struct types, or both are the same +/// container type (List, LargeList, ListView, LargeListView, Dictionary) wrapping +/// types that recursively contain structs. +/// +/// Use this predicate at both planning time (to decide whether to apply struct +/// compatibility validation) and execution time (to decide whether to route +/// through [`cast_column`] instead of Arrow's generic cast). +pub fn requires_nested_struct_cast( + source_type: &DataType, + target_type: &DataType, +) -> bool { + match (source_type, target_type) { + (Struct(_), Struct(_)) => true, + (DataType::List(s), DataType::List(t)) + | (DataType::LargeList(s), DataType::LargeList(t)) + | (DataType::ListView(s), DataType::ListView(t)) + | (DataType::LargeListView(s), DataType::LargeListView(t)) => { + requires_nested_struct_cast(s.data_type(), t.data_type()) + } + (DataType::Dictionary(_, s_val), DataType::Dictionary(_, t_val)) => { + requires_nested_struct_cast(s_val, t_val) + } + _ => false, + } +} + /// Check if two field lists have at least one common field by name. /// /// This is useful for validating struct compatibility when casting between structs, @@ -325,15 +504,14 @@ pub fn has_one_of_more_common_fields( #[cfg(test)] mod tests { - use super::*; use crate::{assert_contains, format::DEFAULT_CAST_OPTIONS}; use arrow::{ array::{ - BinaryArray, Int32Array, Int32Builder, Int64Array, ListArray, MapArray, - MapBuilder, NullArray, StringArray, StringBuilder, + BinaryArray, Int32Array, Int32Builder, Int64Array, ListArray, ListViewArray, + MapArray, MapBuilder, NullArray, StringArray, StringBuilder, }, - buffer::NullBuffer, + buffer::{NullBuffer, ScalarBuffer}, datatypes::{DataType, Field, FieldRef, Int32Type}, }; /// Macro to extract and downcast a column from a StructArray @@ -376,7 +554,9 @@ mod tests { fn test_cast_simple_column() { let source = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; let target_field = field("ints", DataType::Int64); - let result = cast_column(&source, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let result = + cast_column(&source, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); let result = result.as_any().downcast_ref::().unwrap(); assert_eq!(result.len(), 3); assert_eq!(result.value(0), 1); @@ -394,14 +574,15 @@ mod tests { safe: false, ..DEFAULT_CAST_OPTIONS }; - assert!(cast_column(&source, &target_field, &safe_opts).is_err()); + assert!(cast_column(&source, target_field.data_type(), &safe_opts).is_err()); let unsafe_opts = CastOptions { // safe: true - return Null for failure safe: true, ..DEFAULT_CAST_OPTIONS }; - let result = cast_column(&source, &target_field, &unsafe_opts).unwrap(); + let result = + cast_column(&source, target_field.data_type(), &unsafe_opts).unwrap(); let result = result.as_any().downcast_ref::().unwrap(); assert_eq!(result.value(0), 1); assert!(result.is_null(1)); @@ -422,7 +603,8 @@ mod tests { ); let result = - cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); let struct_array = result.as_any().downcast_ref::().unwrap(); assert_eq!(struct_array.fields().len(), 2); let a_result = get_column_as!(&struct_array, "a", Int32Array); @@ -440,7 +622,8 @@ mod tests { let source = Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef; let target_field = struct_field("s", vec![field("a", DataType::Int32)]); - let result = cast_column(&source, &target_field, &DEFAULT_CAST_OPTIONS); + let result = + cast_column(&source, target_field.data_type(), &DEFAULT_CAST_OPTIONS); assert!(result.is_err()); let error_msg = result.unwrap_err().to_string(); assert!(error_msg.contains("Cannot cast column of type")); @@ -460,7 +643,8 @@ mod tests { let target_field = struct_field("s", vec![field("a", DataType::Int32)]); - let result = cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS); + let result = + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS); assert!(result.is_err()); let error_msg = result.unwrap_err().to_string(); assert!(error_msg.contains("Cannot cast struct field 'a'")); @@ -557,7 +741,8 @@ mod tests { let target_field = struct_field("s", vec![field("a", DataType::Int64)]); let result = - cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); let struct_array = result.as_any().downcast_ref::().unwrap(); assert_eq!(struct_array.null_count(), 1); assert!(struct_array.is_valid(0)); @@ -767,7 +952,8 @@ mod tests { ); let result = - cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); let outer = result.as_any().downcast_ref::().unwrap(); let inner = get_column_as!(&outer, "inner", StructArray); assert_eq!(inner.fields().len(), 3); @@ -802,7 +988,8 @@ mod tests { ); let result = - cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); let outer = result.as_any().downcast_ref::().unwrap(); let inner = get_column_as!(&outer, "inner", StructArray); assert_eq!(inner.len(), 2); @@ -883,7 +1070,8 @@ mod tests { ); let result = - cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); let struct_array = result.as_any().downcast_ref::().unwrap(); let arr = get_column_as!(&struct_array, "arr", ListArray); @@ -922,7 +1110,8 @@ mod tests { ); let result = - cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); let struct_array = result.as_any().downcast_ref::().unwrap(); let b_col = get_column_as!(&struct_array, "b", Int64Array); @@ -951,7 +1140,8 @@ mod tests { vec![field("a", DataType::Int64), field("b", DataType::Utf8)], ); - let result = cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS); + let result = + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS); assert!(result.is_err()); let error_msg = result.unwrap_err().to_string(); assert_contains!(error_msg, "no field name overlap"); @@ -974,7 +1164,8 @@ mod tests { ); // Should fail because 'b' is non-nullable but missing from source - let result = cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS); + let result = + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS); assert!(result.is_err()); let err = result.unwrap_err(); assert!( @@ -999,7 +1190,8 @@ mod tests { // Should succeed - 'b' is nullable so can be filled with NULL let result = - cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); let struct_array = result.as_any().downcast_ref::().unwrap(); let a_col = get_column_as!(&struct_array, "a", Int32Array); @@ -1010,4 +1202,138 @@ mod tests { assert!(b_col.is_null(0)); assert!(b_col.is_null(1)); } + + #[test] + fn test_validate_dictionary_value_evolution() { + let source_inner = struct_type(vec![field("a", DataType::Int32)]); + let target_inner = struct_type(vec![ + field("a", DataType::Int32), + field("b", DataType::Utf8), + ]); + let source = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(source_inner)); + let target = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(target_inner)); + assert!(validate_data_type_compatibility("col", &source, &target).is_ok()); + } + + #[test] + fn test_cast_dictionary_struct_value() { + // Build a Dictionary and cast to + // Dictionary (field added, type widened). + let struct_arr = StructArray::from(vec![( + arc_field("a", DataType::Int32), + Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef, + )]); + // keys: [0, null, 1] mapping into the 2-element struct values array. + let keys = Int32Array::from(vec![Some(0), None, Some(1)]); + let source_dict = DictionaryArray::::new(keys, Arc::new(struct_arr)); + let source_col: ArrayRef = Arc::new(source_dict); + + let target_type = DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(struct_type(vec![ + field("a", DataType::Int64), + field("b", DataType::Utf8), + ])), + ); + + let result = + cast_column(&source_col, &target_type, &DEFAULT_CAST_OPTIONS).unwrap(); + let result_dict = result + .as_any() + .downcast_ref::>() + .unwrap(); + + assert!(result_dict.is_valid(0)); + assert!(result_dict.is_null(1)); + assert!(result_dict.is_valid(2)); + + let struct_values = result_dict + .values() + .as_any() + .downcast_ref::() + .unwrap(); + let a_col = get_column_as!(&struct_values, "a", Int64Array); + assert_eq!(a_col.values(), &[10, 20]); + let b_col = get_column_as!(&struct_values, "b", StringArray); + assert!(b_col.iter().all(|v| v.is_none())); + } + + #[test] + fn test_cast_list_view_struct() { + // Build a ListView and cast to + // ListView. + let struct_arr = StructArray::from(vec![( + arc_field("a", DataType::Int32), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + )]); + + let source_field = + arc_field("item", struct_type(vec![field("a", DataType::Int32)])); + let target_field = arc_field( + "item", + struct_type(vec![ + field("a", DataType::Int64), + field("b", DataType::Utf8), + ]), + ); + + // Two list-view entries: [0..2] and [2..3] + let list_view = ListViewArray::new( + source_field, + ScalarBuffer::from(vec![0i32, 2]), + ScalarBuffer::from(vec![2i32, 1]), + Arc::new(struct_arr), + None, + ); + let source_col: ArrayRef = Arc::new(list_view); + + let target_type = DataType::ListView(target_field); + + let result = + cast_column(&source_col, &target_type, &DEFAULT_CAST_OPTIONS).unwrap(); + let result_lv = result.as_any().downcast_ref::().unwrap(); + assert_eq!(result_lv.len(), 2); + + let struct_values = result_lv + .values() + .as_any() + .downcast_ref::() + .unwrap(); + let a_col = get_column_as!(&struct_values, "a", Int64Array); + assert_eq!(a_col.values(), &[1, 2, 3]); + let b_col = get_column_as!(&struct_values, "b", StringArray); + assert!(b_col.iter().all(|v| v.is_none())); + } + + #[test] + fn test_requires_nested_struct_cast() { + let s1 = struct_type(vec![field("a", DataType::Int32)]); + let s2 = struct_type(vec![field("a", DataType::Int64)]); + + assert!(requires_nested_struct_cast(&s1, &s2)); + assert!(requires_nested_struct_cast( + &DataType::List(arc_field("item", s1.clone())), + &DataType::List(arc_field("item", s2.clone())), + )); + assert!(requires_nested_struct_cast( + &DataType::Dictionary(Box::new(DataType::Int32), Box::new(s1.clone())), + &DataType::Dictionary(Box::new(DataType::Int32), Box::new(s2.clone())), + )); + assert!(requires_nested_struct_cast( + &DataType::ListView(arc_field("item", s1)), + &DataType::ListView(arc_field("item", s2)), + )); + + // Non-struct types should return false. + assert!(!requires_nested_struct_cast( + &DataType::Int32, + &DataType::Int64 + )); + assert!(!requires_nested_struct_cast( + &DataType::List(arc_field("item", DataType::Int32)), + &DataType::List(arc_field("item", DataType::Int64)), + )); + } } diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index d759bbedd9266..2023339a33687 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -3901,20 +3901,16 @@ impl ScalarValue { let scalar_array = self.to_array()?; - // For struct types, use name-based casting logic that matches fields by name - // and recursively casts nested structs. The field name wrapper is arbitrary - // since cast_column only uses the DataType::Struct field definitions inside. - let cast_arr = match target_type { - DataType::Struct(_) => { - // Field name is unused; only the struct's inner field names matter - let target_field = Field::new("_", target_type.clone(), true); - crate::nested_struct::cast_column( - &scalar_array, - &target_field, - cast_options, - )? - } - _ => cast_with_options(&scalar_array, target_type, cast_options)?, + // For types that contain structs (including nested inside Lists, Dictionaries, + // etc.), use name-based casting logic that matches struct fields by name and + // recursively casts nested structs. + let cast_arr = if crate::nested_struct::requires_nested_struct_cast( + scalar_array.data_type(), + target_type, + ) { + crate::nested_struct::cast_column(&scalar_array, target_type, cast_options)? + } else { + cast_with_options(&scalar_array, target_type, cast_options)? }; ScalarValue::try_from_array(&cast_arr, 0) diff --git a/datafusion/expr-common/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs index 1aa42470a1481..bc6b8177ab3cf 100644 --- a/datafusion/expr-common/src/columnar_value.rs +++ b/datafusion/expr-common/src/columnar_value.rs @@ -20,7 +20,7 @@ use arrow::{ array::{Array, ArrayRef, Date32Array, Date64Array, NullArray}, compute::{CastOptions, kernels, max, min}, - datatypes::{DataType, Field}, + datatypes::DataType, util::pretty::pretty_format_columns, }; use datafusion_common::internal_datafusion_err; @@ -313,24 +313,18 @@ fn cast_array_by_name( return Ok(Arc::clone(array)); } - match cast_type { - DataType::Struct(_) => { - // Field name is unused; only the struct's inner field names matter - let target_field = Field::new("_", cast_type.clone(), true); - datafusion_common::nested_struct::cast_column( - array, - &target_field, - cast_options, - ) - } - _ => { - ensure_date_array_timestamp_bounds(array, cast_type)?; - Ok(kernels::cast::cast_with_options( - array, - cast_type, - cast_options, - )?) - } + if datafusion_common::nested_struct::requires_nested_struct_cast( + array.data_type(), + cast_type, + ) { + datafusion_common::nested_struct::cast_column(array, cast_type, cast_options) + } else { + ensure_date_array_timestamp_bounds(array, cast_type)?; + Ok(kernels::cast::cast_with_options( + array, + cast_type, + cast_options, + )?) } } diff --git a/datafusion/physical-expr-adapter/src/schema_rewriter.rs b/datafusion/physical-expr-adapter/src/schema_rewriter.rs index a2a45cbdfe7aa..f2677dbe3f0cd 100644 --- a/datafusion/physical-expr-adapter/src/schema_rewriter.rs +++ b/datafusion/physical-expr-adapter/src/schema_rewriter.rs @@ -25,12 +25,11 @@ use std::hash::Hash; use std::sync::Arc; use arrow::array::RecordBatch; -use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, FieldRef, SchemaRef}; use datafusion_common::{ - Result, ScalarValue, exec_err, + DataFusionError, Result, ScalarValue, exec_err, metadata::FieldMetadata, - nested_struct::validate_struct_compatibility, + nested_struct::validate_data_type_compatibility, tree_node::{Transformed, TransformedResult, TreeNode}, }; use datafusion_functions::core::getfield::GetFieldFunc; @@ -487,31 +486,18 @@ impl DefaultPhysicalExprAdapterRewriter { physical_field: FieldRef, logical_field: &Field, ) -> Result>> { - // For struct types, use validate_struct_compatibility which handles: - // - Missing fields in source (filled with nulls) - // - Extra fields in source (ignored) - // - Recursive validation of nested structs - // For non-struct types, use Arrow's can_cast_types - match (physical_field.data_type(), logical_field.data_type()) { - (DataType::Struct(physical_fields), DataType::Struct(logical_fields)) => { - validate_struct_compatibility( - physical_fields.as_ref(), - logical_fields.as_ref(), - )?; - } - _ => { - let is_compatible = - can_cast_types(physical_field.data_type(), logical_field.data_type()); - if !is_compatible { - return exec_err!( - "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)", + validate_data_type_compatibility( + column.name(), + physical_field.data_type(), + logical_field.data_type(), + ) + .map_err(|e| + DataFusionError::Execution(format!( + "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type): {e}", column.name(), physical_field.data_type(), logical_field.data_type() - ); - } - } - } + )))?; let cast_expr = Arc::new(CastColumnExpr::new( Arc::new(column), @@ -663,8 +649,8 @@ impl BatchAdapter { mod tests { use super::*; use arrow::array::{ - BooleanArray, Int32Array, Int64Array, RecordBatch, RecordBatchOptions, - StringArray, StringViewArray, StructArray, + Array, BooleanArray, GenericListArray, Int32Array, Int64Array, RecordBatch, + RecordBatchOptions, StringArray, StringViewArray, StructArray, }; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; use datafusion_common::{Result, ScalarValue, assert_contains, record_batch}; @@ -1289,6 +1275,142 @@ mod tests { assert_eq!(extra_values.iter().collect_vec(), vec![None, None, None]); } + /// Test that List columns are properly adapted with struct evolution. + #[test] + fn test_adapt_list_struct_batches() { + // Physical: List<{id: Int32, name: Utf8}> + let physical_struct_fields: Fields = vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ] + .into(); + + let struct_array = StructArray::new( + physical_struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])) as _, + Arc::new(StringArray::from(vec![ + Some("alice"), + None, + Some("charlie"), + ])) as _, + ], + None, + ); + + // One list element per row + let item_field = Arc::new(Field::new( + "item", + DataType::Struct(physical_struct_fields.clone()), + true, + )); + let offsets = + arrow::buffer::OffsetBuffer::from_lengths(vec![1usize; struct_array.len()]); + let list_array = GenericListArray::::new( + item_field, + offsets, + Arc::new(struct_array), + None, + ); + + let physical_schema = Arc::new(Schema::new(vec![Field::new( + "data", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(physical_struct_fields), + true, + ))), + false, + )])); + + let physical_batch = RecordBatch::try_new( + Arc::clone(&physical_schema), + vec![Arc::new(list_array)], + ) + .unwrap(); + + // Logical: List<{id: Int64, name: Utf8View, extra: Boolean}> + let logical_struct_fields: Fields = vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8View, true), + Field::new("extra", DataType::Boolean, true), + ] + .into(); + + let logical_schema = Arc::new(Schema::new(vec![Field::new( + "data", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(logical_struct_fields.clone()), + true, + ))), + false, + )])); + + let projection = vec![col("data", &logical_schema).unwrap()]; + + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = factory + .create(Arc::clone(&logical_schema), Arc::clone(&physical_schema)) + .unwrap(); + + let adapted_projection = projection + .into_iter() + .map(|expr| adapter.rewrite(expr).unwrap()) + .collect_vec(); + + let adapted_schema = Arc::new(Schema::new( + adapted_projection + .iter() + .map(|expr| expr.return_field(&physical_schema).unwrap()) + .collect_vec(), + )); + + let res = batch_project( + adapted_projection, + &physical_batch, + Arc::clone(&adapted_schema), + ) + .unwrap(); + + assert_eq!(res.num_columns(), 1); + + let result_list = res + .column(0) + .as_any() + .downcast_ref::>() + .unwrap(); + + // Check each list element contains the evolved struct + assert_eq!(result_list.len(), 3); + let flat_structs = result_list + .values() + .as_any() + .downcast_ref::() + .unwrap(); + + let id_col = flat_structs.column_by_name("id").unwrap(); + assert_eq!(id_col.data_type(), &DataType::Int64); + let id_values = id_col.as_any().downcast_ref::().unwrap(); + assert_eq!( + id_values.iter().collect_vec(), + vec![Some(1), Some(2), Some(3)] + ); + + let name_col = flat_structs.column_by_name("name").unwrap(); + assert_eq!(name_col.data_type(), &DataType::Utf8View); + let name_values = name_col.as_any().downcast_ref::().unwrap(); + assert_eq!( + name_values.iter().collect_vec(), + vec![Some("alice"), None, Some("charlie")] + ); + + let extra_col = flat_structs.column_by_name("extra").unwrap(); + assert_eq!(extra_col.data_type(), &DataType::Boolean); + let extra_values = extra_col.as_any().downcast_ref::().unwrap(); + assert_eq!(extra_values.iter().collect_vec(), vec![None, None, None]); + } + #[test] fn test_try_rewrite_struct_field_access() { // Test the core logic of try_rewrite_struct_field_access diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 5a80daf663a36..24e486f8050fe 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -27,7 +27,9 @@ use arrow::datatypes::{DataType, DataType::*, FieldRef, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::datatype::DataTypeExt; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; -use datafusion_common::nested_struct::validate_struct_compatibility; +use datafusion_common::nested_struct::{ + requires_nested_struct_cast, validate_data_type_compatibility, +}; use datafusion_common::{Result, not_impl_err}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; @@ -43,20 +45,14 @@ const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions { format_options: DEFAULT_FORMAT_OPTIONS, }; -/// Check if struct-to-struct casting is allowed by validating field compatibility. +/// Check if name-based struct casting is allowed by validating field compatibility. /// /// This function applies the same validation rules as execution time to ensure /// planning-time validation matches runtime validation, enabling fail-fast behavior -/// instead of deferring errors to execution. -fn can_cast_struct_types(source: &DataType, target: &DataType) -> bool { - match (source, target) { - (Struct(source_fields), Struct(target_fields)) => { - // Apply the same struct compatibility rules as at execution time. - // This ensures planning-time validation matches execution-time validation. - validate_struct_compatibility(source_fields, target_fields).is_ok() - } - _ => false, - } +/// instead of deferring errors to execution. Handles structs at any nesting level +/// (e.g., `List`, `Dictionary<_, Struct>`). +fn can_cast_named_struct_types(source: &DataType, target: &DataType) -> bool { + validate_data_type_compatibility("", source, target).is_ok() } /// CAST expression casts an expression to a specific data type and returns a runtime error on invalid cast @@ -323,12 +319,13 @@ pub fn cast_with_options( let expr_type = expr.data_type(input_schema)?; if expr_type == cast_type { Ok(Arc::clone(&expr)) - } else if matches!((&expr_type, &cast_type), (Struct(_), Struct(_))) { - if can_cast_struct_types(&expr_type, &cast_type) { - // Allow struct-to-struct casts that pass name-based compatibility validation. - // This validation is applied at planning time (now) to fail fast, rather than - // deferring errors to execution time. The name-based casting logic will be - // executed at runtime via ColumnarValue::cast_to. + } else if requires_nested_struct_cast(&expr_type, &cast_type) { + if can_cast_named_struct_types(&expr_type, &cast_type) { + // Allow casts involving structs (including nested inside Lists, Dictionaries, + // etc.) that pass name-based compatibility validation. This validation is + // applied at planning time (now) to fail fast, rather than deferring errors + // to execution time. The name-based casting logic will be executed at runtime + // via ColumnarValue::cast_to. Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options))) } else { not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}") diff --git a/datafusion/physical-expr/src/expressions/cast_column.rs b/datafusion/physical-expr/src/expressions/cast_column.rs index a99953abdb5cb..158d467c08b59 100644 --- a/datafusion/physical-expr/src/expressions/cast_column.rs +++ b/datafusion/physical-expr/src/expressions/cast_column.rs @@ -140,15 +140,18 @@ impl PhysicalExpr for CastColumnExpr { let value = self.expr.evaluate(batch)?; match value { ColumnarValue::Array(array) => { - let casted = - cast_column(&array, self.target_field.as_ref(), &self.cast_options)?; + let casted = cast_column( + &array, + self.target_field.data_type(), + &self.cast_options, + )?; Ok(ColumnarValue::Array(casted)) } ColumnarValue::Scalar(scalar) => { let as_array = scalar.to_array_of_size(1)?; let casted = cast_column( &as_array, - self.target_field.as_ref(), + self.target_field.data_type(), &self.cast_options, )?; let result = ScalarValue::try_from_array(casted.as_ref(), 0)?;