From 6a2a0c717ff6b19e11c2b0e2ae1afc6bc1d9f316 Mon Sep 17 00:00:00 2001 From: Alfonso Subiotto Marques Date: Thu, 12 Mar 2026 12:49:59 +0100 Subject: [PATCH] physical-expr-adapter: support casting structs nested inside complex types Previously, DefaultPhysicalExprAdapterRewriter.create_cast_column_expr would only allow struct evolution for top-level structs. However, it is valid to have structs that are nested inside other complex types (e.g. Dictionary<_, Struct>, List). These cases would previously return an error although they were valid. This commit handles Structs that are arbitrarily nested within certain complex types (Lists, ListViews, and Dicts). Other complex types can be supported in the future, but I think this is good enough for now. Signed-off-by: Alfonso Subiotto Marques --- datafusion/common/src/nested_struct.rs | 426 ++++++++++++++++-- datafusion/common/src/scalar/mod.rs | 24 +- datafusion/expr-common/src/columnar_value.rs | 32 +- .../src/schema_rewriter.rs | 176 ++++++-- .../physical-expr/src/expressions/cast.rs | 33 +- .../src/expressions/cast_column.rs | 9 +- 6 files changed, 569 insertions(+), 131 deletions(-) diff --git a/datafusion/common/src/nested_struct.rs b/datafusion/common/src/nested_struct.rs index bf2558f313069..cdd6215d08e2f 100644 --- a/datafusion/common/src/nested_struct.rs +++ b/datafusion/common/src/nested_struct.rs @@ -17,8 +17,11 @@ use crate::error::{_plan_err, Result}; use arrow::{ - array::{Array, ArrayRef, StructArray, new_null_array}, - compute::{CastOptions, cast_with_options}, + array::{ + Array, ArrayRef, DictionaryArray, GenericListArray, GenericListViewArray, + StructArray, downcast_integer, new_null_array, + }, + compute::{CastOptions, can_cast_types, cast_with_options}, datatypes::{DataType, DataType::Struct, Field, FieldRef}, }; use std::{collections::HashSet, sync::Arc}; @@ -80,14 +83,17 @@ fn cast_struct_column( match source_child_opt { Some(source_child_col) => { - let adapted_child = - cast_column(source_child_col, target_child_field, cast_options) - .map_err(|e| { - e.context(format!( - "While casting struct field '{}'", - target_child_field.name() - )) - })?; + let adapted_child = cast_column( + source_child_col, + target_child_field.data_type(), + cast_options, + ) + .map_err(|e| { + e.context(format!( + "While casting struct field '{}'", + target_child_field.name() + )) + })?; arrays.push(adapted_child); } None => { @@ -127,18 +133,17 @@ fn cast_struct_column( /// ``` /// use arrow::array::{ArrayRef, Int64Array}; /// use arrow::compute::CastOptions; -/// use arrow::datatypes::{DataType, Field}; +/// use arrow::datatypes::DataType; /// use datafusion_common::nested_struct::cast_column; /// use std::sync::Arc; /// /// let source: ArrayRef = Arc::new(Int64Array::from(vec![1, i64::MAX])); -/// let target = Field::new("ints", DataType::Int32, true); /// // Permit lossy conversions by producing NULL on overflow instead of erroring /// let options = CastOptions { /// safe: true, /// ..Default::default() /// }; -/// let result = cast_column(&source, &target, &options).unwrap(); +/// let result = cast_column(&source, &DataType::Int32, &options).unwrap(); /// assert!(result.is_null(1)); /// ``` /// @@ -151,7 +156,7 @@ fn cast_struct_column( /// /// # Arguments /// * `source_col` - The source array to cast -/// * `target_field` - The target field definition (including type and metadata) +/// * `target_type` - The target data type to cast to /// * `cast_options` - Options that govern strictness and formatting of the cast /// /// # Returns @@ -165,18 +170,139 @@ fn cast_struct_column( /// - Invalid data type combinations are encountered pub fn cast_column( source_col: &ArrayRef, - target_field: &Field, + target_type: &DataType, cast_options: &CastOptions, ) -> Result { - match target_field.data_type() { - Struct(target_fields) => { + match (source_col.data_type(), target_type) { + (_, Struct(target_fields)) => { cast_struct_column(source_col, target_fields, cast_options) } - _ => Ok(cast_with_options( + (DataType::List(_), DataType::List(target_inner)) => { + cast_list_column::(source_col, target_inner, cast_options) + } + (DataType::LargeList(_), DataType::LargeList(target_inner)) => { + cast_list_column::(source_col, target_inner, cast_options) + } + (DataType::ListView(_), DataType::ListView(target_inner)) => { + cast_list_view_column::(source_col, target_inner, cast_options) + } + (DataType::LargeListView(_), DataType::LargeListView(target_inner)) => { + cast_list_view_column::(source_col, target_inner, cast_options) + } + ( + DataType::Dictionary(source_key_type, _), + DataType::Dictionary(target_key_type, target_value_type), + ) => cast_dictionary_column( source_col, - target_field.data_type(), + source_key_type, + target_key_type, + target_value_type, cast_options, - )?), + ), + _ => Ok(cast_with_options(source_col, target_type, cast_options)?), + } +} + +fn cast_list_column( + source_col: &ArrayRef, + target_inner_field: &FieldRef, + cast_options: &CastOptions, +) -> Result { + let source_list = source_col + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + crate::error::DataFusionError::Plan(format!( + "Expected list array but got {}", + source_col.data_type() + )) + })?; + + let cast_values = cast_column( + source_list.values(), + target_inner_field.data_type(), + cast_options, + )?; + + let result = GenericListArray::::new( + Arc::clone(target_inner_field), + source_list.offsets().clone(), + cast_values, + source_list.nulls().cloned(), + ); + Ok(Arc::new(result)) +} + +fn cast_list_view_column( + source_col: &ArrayRef, + target_inner_field: &FieldRef, + cast_options: &CastOptions, +) -> Result { + let source_list = source_col + .as_any() + .downcast_ref::>() + .ok_or_else(|| { + crate::error::DataFusionError::Plan(format!( + "Expected list view array but got {}", + source_col.data_type() + )) + })?; + + let cast_values = cast_column( + source_list.values(), + target_inner_field.data_type(), + cast_options, + )?; + + let result = GenericListViewArray::::try_new( + Arc::clone(target_inner_field), + source_list.offsets().clone(), + source_list.sizes().clone(), + cast_values, + source_list.nulls().cloned(), + )?; + Ok(Arc::new(result)) +} + +fn cast_dictionary_column( + source_col: &ArrayRef, + source_key_type: &DataType, + target_key_type: &DataType, + target_value_type: &DataType, + cast_options: &CastOptions, +) -> Result { + // Dispatch on source key type to access keys/values, then recursively + // cast values. Rebuild with the source key type first. + macro_rules! cast_dict_values { + ($t:ty) => {{ + let source_dict = source_col + .as_any() + .downcast_ref::>() + .expect("downcast must succeed"); + let cast_values = + cast_column(source_dict.values(), target_value_type, cast_options)?; + Ok(Arc::new(DictionaryArray::<$t>::new( + source_dict.keys().clone(), + cast_values, + )) as ArrayRef) + }}; + } + + let result: Result = downcast_integer! { + source_key_type => (cast_dict_values), + k => _plan_err!("Unsupported dictionary key type: {k}") + }; + let result = result?; + + // If key types differ, delegate key casting to Arrow. + if source_key_type != target_key_type { + let target_dict_type = DataType::Dictionary( + Box::new(target_key_type.clone()), + Box::new(target_value_type.clone()), + ); + Ok(cast_with_options(&result, &target_dict_type, cast_options)?) + } else { + Ok(result) } } @@ -281,31 +407,84 @@ fn validate_field_compatibility( ); } - // Check if the matching field types are compatible - match (source_field.data_type(), target_field.data_type()) { - // Recursively validate nested structs + validate_data_type_compatibility( + target_field.name(), + source_field.data_type(), + target_field.data_type(), + ) +} + +/// Validates that `source_type` can be cast to `target_type`, recursively +/// handling container types that wrap structs. +pub fn validate_data_type_compatibility( + field_name: &str, + source_type: &DataType, + target_type: &DataType, +) -> Result<()> { + match (source_type, target_type) { (Struct(source_nested), Struct(target_nested)) => { validate_struct_compatibility(source_nested, target_nested)?; } - // For non-struct types, use the existing castability check + (DataType::List(s), DataType::List(t)) + | (DataType::LargeList(s), DataType::LargeList(t)) + | (DataType::ListView(s), DataType::ListView(t)) + | (DataType::LargeListView(s), DataType::LargeListView(t)) => { + validate_field_compatibility(s, t)?; + } + (DataType::Dictionary(s_key, s_val), DataType::Dictionary(t_key, t_val)) => { + if !can_cast_types(s_key, t_key) { + return _plan_err!( + "Cannot cast dictionary key type {} to {} for field '{}'", + s_key, + t_key, + field_name + ); + } + validate_data_type_compatibility(field_name, s_val, t_val)?; + } _ => { - if !arrow::compute::can_cast_types( - source_field.data_type(), - target_field.data_type(), - ) { + if !can_cast_types(source_type, target_type) { return _plan_err!( "Cannot cast struct field '{}' from type {} to type {}", - target_field.name(), - source_field.data_type(), - target_field.data_type() + field_name, + source_type, + target_type ); } } } - Ok(()) } +/// Returns true if casting from `source_type` to `target_type` requires +/// name-based nested struct casting logic, rather than Arrow's standard cast. +/// +/// This is the case when both types are struct types, or both are the same +/// container type (List, LargeList, ListView, LargeListView, Dictionary) wrapping +/// types that recursively contain structs. +/// +/// Use this predicate at both planning time (to decide whether to apply struct +/// compatibility validation) and execution time (to decide whether to route +/// through [`cast_column`] instead of Arrow's generic cast). +pub fn requires_nested_struct_cast( + source_type: &DataType, + target_type: &DataType, +) -> bool { + match (source_type, target_type) { + (Struct(_), Struct(_)) => true, + (DataType::List(s), DataType::List(t)) + | (DataType::LargeList(s), DataType::LargeList(t)) + | (DataType::ListView(s), DataType::ListView(t)) + | (DataType::LargeListView(s), DataType::LargeListView(t)) => { + requires_nested_struct_cast(s.data_type(), t.data_type()) + } + (DataType::Dictionary(_, s_val), DataType::Dictionary(_, t_val)) => { + requires_nested_struct_cast(s_val, t_val) + } + _ => false, + } +} + /// Check if two field lists have at least one common field by name. /// /// This is useful for validating struct compatibility when casting between structs, @@ -325,15 +504,14 @@ pub fn has_one_of_more_common_fields( #[cfg(test)] mod tests { - use super::*; use crate::{assert_contains, format::DEFAULT_CAST_OPTIONS}; use arrow::{ array::{ - BinaryArray, Int32Array, Int32Builder, Int64Array, ListArray, MapArray, - MapBuilder, NullArray, StringArray, StringBuilder, + BinaryArray, Int32Array, Int32Builder, Int64Array, ListArray, ListViewArray, + MapArray, MapBuilder, NullArray, StringArray, StringBuilder, }, - buffer::NullBuffer, + buffer::{NullBuffer, ScalarBuffer}, datatypes::{DataType, Field, FieldRef, Int32Type}, }; /// Macro to extract and downcast a column from a StructArray @@ -376,7 +554,9 @@ mod tests { fn test_cast_simple_column() { let source = Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef; let target_field = field("ints", DataType::Int64); - let result = cast_column(&source, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + let result = + cast_column(&source, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); let result = result.as_any().downcast_ref::().unwrap(); assert_eq!(result.len(), 3); assert_eq!(result.value(0), 1); @@ -394,14 +574,15 @@ mod tests { safe: false, ..DEFAULT_CAST_OPTIONS }; - assert!(cast_column(&source, &target_field, &safe_opts).is_err()); + assert!(cast_column(&source, target_field.data_type(), &safe_opts).is_err()); let unsafe_opts = CastOptions { // safe: true - return Null for failure safe: true, ..DEFAULT_CAST_OPTIONS }; - let result = cast_column(&source, &target_field, &unsafe_opts).unwrap(); + let result = + cast_column(&source, target_field.data_type(), &unsafe_opts).unwrap(); let result = result.as_any().downcast_ref::().unwrap(); assert_eq!(result.value(0), 1); assert!(result.is_null(1)); @@ -422,7 +603,8 @@ mod tests { ); let result = - cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); let struct_array = result.as_any().downcast_ref::().unwrap(); assert_eq!(struct_array.fields().len(), 2); let a_result = get_column_as!(&struct_array, "a", Int32Array); @@ -440,7 +622,8 @@ mod tests { let source = Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef; let target_field = struct_field("s", vec![field("a", DataType::Int32)]); - let result = cast_column(&source, &target_field, &DEFAULT_CAST_OPTIONS); + let result = + cast_column(&source, target_field.data_type(), &DEFAULT_CAST_OPTIONS); assert!(result.is_err()); let error_msg = result.unwrap_err().to_string(); assert!(error_msg.contains("Cannot cast column of type")); @@ -460,7 +643,8 @@ mod tests { let target_field = struct_field("s", vec![field("a", DataType::Int32)]); - let result = cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS); + let result = + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS); assert!(result.is_err()); let error_msg = result.unwrap_err().to_string(); assert!(error_msg.contains("Cannot cast struct field 'a'")); @@ -557,7 +741,8 @@ mod tests { let target_field = struct_field("s", vec![field("a", DataType::Int64)]); let result = - cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); let struct_array = result.as_any().downcast_ref::().unwrap(); assert_eq!(struct_array.null_count(), 1); assert!(struct_array.is_valid(0)); @@ -767,7 +952,8 @@ mod tests { ); let result = - cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); let outer = result.as_any().downcast_ref::().unwrap(); let inner = get_column_as!(&outer, "inner", StructArray); assert_eq!(inner.fields().len(), 3); @@ -802,7 +988,8 @@ mod tests { ); let result = - cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); let outer = result.as_any().downcast_ref::().unwrap(); let inner = get_column_as!(&outer, "inner", StructArray); assert_eq!(inner.len(), 2); @@ -883,7 +1070,8 @@ mod tests { ); let result = - cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); let struct_array = result.as_any().downcast_ref::().unwrap(); let arr = get_column_as!(&struct_array, "arr", ListArray); @@ -922,7 +1110,8 @@ mod tests { ); let result = - cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); let struct_array = result.as_any().downcast_ref::().unwrap(); let b_col = get_column_as!(&struct_array, "b", Int64Array); @@ -951,7 +1140,8 @@ mod tests { vec![field("a", DataType::Int64), field("b", DataType::Utf8)], ); - let result = cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS); + let result = + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS); assert!(result.is_err()); let error_msg = result.unwrap_err().to_string(); assert_contains!(error_msg, "no field name overlap"); @@ -974,7 +1164,8 @@ mod tests { ); // Should fail because 'b' is non-nullable but missing from source - let result = cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS); + let result = + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS); assert!(result.is_err()); let err = result.unwrap_err(); assert!( @@ -999,7 +1190,8 @@ mod tests { // Should succeed - 'b' is nullable so can be filled with NULL let result = - cast_column(&source_col, &target_field, &DEFAULT_CAST_OPTIONS).unwrap(); + cast_column(&source_col, target_field.data_type(), &DEFAULT_CAST_OPTIONS) + .unwrap(); let struct_array = result.as_any().downcast_ref::().unwrap(); let a_col = get_column_as!(&struct_array, "a", Int32Array); @@ -1010,4 +1202,138 @@ mod tests { assert!(b_col.is_null(0)); assert!(b_col.is_null(1)); } + + #[test] + fn test_validate_dictionary_value_evolution() { + let source_inner = struct_type(vec![field("a", DataType::Int32)]); + let target_inner = struct_type(vec![ + field("a", DataType::Int32), + field("b", DataType::Utf8), + ]); + let source = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(source_inner)); + let target = + DataType::Dictionary(Box::new(DataType::Int32), Box::new(target_inner)); + assert!(validate_data_type_compatibility("col", &source, &target).is_ok()); + } + + #[test] + fn test_cast_dictionary_struct_value() { + // Build a Dictionary and cast to + // Dictionary (field added, type widened). + let struct_arr = StructArray::from(vec![( + arc_field("a", DataType::Int32), + Arc::new(Int32Array::from(vec![10, 20])) as ArrayRef, + )]); + // keys: [0, null, 1] mapping into the 2-element struct values array. + let keys = Int32Array::from(vec![Some(0), None, Some(1)]); + let source_dict = DictionaryArray::::new(keys, Arc::new(struct_arr)); + let source_col: ArrayRef = Arc::new(source_dict); + + let target_type = DataType::Dictionary( + Box::new(DataType::Int32), + Box::new(struct_type(vec![ + field("a", DataType::Int64), + field("b", DataType::Utf8), + ])), + ); + + let result = + cast_column(&source_col, &target_type, &DEFAULT_CAST_OPTIONS).unwrap(); + let result_dict = result + .as_any() + .downcast_ref::>() + .unwrap(); + + assert!(result_dict.is_valid(0)); + assert!(result_dict.is_null(1)); + assert!(result_dict.is_valid(2)); + + let struct_values = result_dict + .values() + .as_any() + .downcast_ref::() + .unwrap(); + let a_col = get_column_as!(&struct_values, "a", Int64Array); + assert_eq!(a_col.values(), &[10, 20]); + let b_col = get_column_as!(&struct_values, "b", StringArray); + assert!(b_col.iter().all(|v| v.is_none())); + } + + #[test] + fn test_cast_list_view_struct() { + // Build a ListView and cast to + // ListView. + let struct_arr = StructArray::from(vec![( + arc_field("a", DataType::Int32), + Arc::new(Int32Array::from(vec![1, 2, 3])) as ArrayRef, + )]); + + let source_field = + arc_field("item", struct_type(vec![field("a", DataType::Int32)])); + let target_field = arc_field( + "item", + struct_type(vec![ + field("a", DataType::Int64), + field("b", DataType::Utf8), + ]), + ); + + // Two list-view entries: [0..2] and [2..3] + let list_view = ListViewArray::new( + source_field, + ScalarBuffer::from(vec![0i32, 2]), + ScalarBuffer::from(vec![2i32, 1]), + Arc::new(struct_arr), + None, + ); + let source_col: ArrayRef = Arc::new(list_view); + + let target_type = DataType::ListView(target_field); + + let result = + cast_column(&source_col, &target_type, &DEFAULT_CAST_OPTIONS).unwrap(); + let result_lv = result.as_any().downcast_ref::().unwrap(); + assert_eq!(result_lv.len(), 2); + + let struct_values = result_lv + .values() + .as_any() + .downcast_ref::() + .unwrap(); + let a_col = get_column_as!(&struct_values, "a", Int64Array); + assert_eq!(a_col.values(), &[1, 2, 3]); + let b_col = get_column_as!(&struct_values, "b", StringArray); + assert!(b_col.iter().all(|v| v.is_none())); + } + + #[test] + fn test_requires_nested_struct_cast() { + let s1 = struct_type(vec![field("a", DataType::Int32)]); + let s2 = struct_type(vec![field("a", DataType::Int64)]); + + assert!(requires_nested_struct_cast(&s1, &s2)); + assert!(requires_nested_struct_cast( + &DataType::List(arc_field("item", s1.clone())), + &DataType::List(arc_field("item", s2.clone())), + )); + assert!(requires_nested_struct_cast( + &DataType::Dictionary(Box::new(DataType::Int32), Box::new(s1.clone())), + &DataType::Dictionary(Box::new(DataType::Int32), Box::new(s2.clone())), + )); + assert!(requires_nested_struct_cast( + &DataType::ListView(arc_field("item", s1)), + &DataType::ListView(arc_field("item", s2)), + )); + + // Non-struct types should return false. + assert!(!requires_nested_struct_cast( + &DataType::Int32, + &DataType::Int64 + )); + assert!(!requires_nested_struct_cast( + &DataType::List(arc_field("item", DataType::Int32)), + &DataType::List(arc_field("item", DataType::Int64)), + )); + } } diff --git a/datafusion/common/src/scalar/mod.rs b/datafusion/common/src/scalar/mod.rs index d759bbedd9266..2023339a33687 100644 --- a/datafusion/common/src/scalar/mod.rs +++ b/datafusion/common/src/scalar/mod.rs @@ -3901,20 +3901,16 @@ impl ScalarValue { let scalar_array = self.to_array()?; - // For struct types, use name-based casting logic that matches fields by name - // and recursively casts nested structs. The field name wrapper is arbitrary - // since cast_column only uses the DataType::Struct field definitions inside. - let cast_arr = match target_type { - DataType::Struct(_) => { - // Field name is unused; only the struct's inner field names matter - let target_field = Field::new("_", target_type.clone(), true); - crate::nested_struct::cast_column( - &scalar_array, - &target_field, - cast_options, - )? - } - _ => cast_with_options(&scalar_array, target_type, cast_options)?, + // For types that contain structs (including nested inside Lists, Dictionaries, + // etc.), use name-based casting logic that matches struct fields by name and + // recursively casts nested structs. + let cast_arr = if crate::nested_struct::requires_nested_struct_cast( + scalar_array.data_type(), + target_type, + ) { + crate::nested_struct::cast_column(&scalar_array, target_type, cast_options)? + } else { + cast_with_options(&scalar_array, target_type, cast_options)? }; ScalarValue::try_from_array(&cast_arr, 0) diff --git a/datafusion/expr-common/src/columnar_value.rs b/datafusion/expr-common/src/columnar_value.rs index 1aa42470a1481..bc6b8177ab3cf 100644 --- a/datafusion/expr-common/src/columnar_value.rs +++ b/datafusion/expr-common/src/columnar_value.rs @@ -20,7 +20,7 @@ use arrow::{ array::{Array, ArrayRef, Date32Array, Date64Array, NullArray}, compute::{CastOptions, kernels, max, min}, - datatypes::{DataType, Field}, + datatypes::DataType, util::pretty::pretty_format_columns, }; use datafusion_common::internal_datafusion_err; @@ -313,24 +313,18 @@ fn cast_array_by_name( return Ok(Arc::clone(array)); } - match cast_type { - DataType::Struct(_) => { - // Field name is unused; only the struct's inner field names matter - let target_field = Field::new("_", cast_type.clone(), true); - datafusion_common::nested_struct::cast_column( - array, - &target_field, - cast_options, - ) - } - _ => { - ensure_date_array_timestamp_bounds(array, cast_type)?; - Ok(kernels::cast::cast_with_options( - array, - cast_type, - cast_options, - )?) - } + if datafusion_common::nested_struct::requires_nested_struct_cast( + array.data_type(), + cast_type, + ) { + datafusion_common::nested_struct::cast_column(array, cast_type, cast_options) + } else { + ensure_date_array_timestamp_bounds(array, cast_type)?; + Ok(kernels::cast::cast_with_options( + array, + cast_type, + cast_options, + )?) } } diff --git a/datafusion/physical-expr-adapter/src/schema_rewriter.rs b/datafusion/physical-expr-adapter/src/schema_rewriter.rs index a2a45cbdfe7aa..f2677dbe3f0cd 100644 --- a/datafusion/physical-expr-adapter/src/schema_rewriter.rs +++ b/datafusion/physical-expr-adapter/src/schema_rewriter.rs @@ -25,12 +25,11 @@ use std::hash::Hash; use std::sync::Arc; use arrow::array::RecordBatch; -use arrow::compute::can_cast_types; use arrow::datatypes::{DataType, Field, FieldRef, SchemaRef}; use datafusion_common::{ - Result, ScalarValue, exec_err, + DataFusionError, Result, ScalarValue, exec_err, metadata::FieldMetadata, - nested_struct::validate_struct_compatibility, + nested_struct::validate_data_type_compatibility, tree_node::{Transformed, TransformedResult, TreeNode}, }; use datafusion_functions::core::getfield::GetFieldFunc; @@ -487,31 +486,18 @@ impl DefaultPhysicalExprAdapterRewriter { physical_field: FieldRef, logical_field: &Field, ) -> Result>> { - // For struct types, use validate_struct_compatibility which handles: - // - Missing fields in source (filled with nulls) - // - Extra fields in source (ignored) - // - Recursive validation of nested structs - // For non-struct types, use Arrow's can_cast_types - match (physical_field.data_type(), logical_field.data_type()) { - (DataType::Struct(physical_fields), DataType::Struct(logical_fields)) => { - validate_struct_compatibility( - physical_fields.as_ref(), - logical_fields.as_ref(), - )?; - } - _ => { - let is_compatible = - can_cast_types(physical_field.data_type(), logical_field.data_type()); - if !is_compatible { - return exec_err!( - "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type)", + validate_data_type_compatibility( + column.name(), + physical_field.data_type(), + logical_field.data_type(), + ) + .map_err(|e| + DataFusionError::Execution(format!( + "Cannot cast column '{}' from '{}' (physical data type) to '{}' (logical data type): {e}", column.name(), physical_field.data_type(), logical_field.data_type() - ); - } - } - } + )))?; let cast_expr = Arc::new(CastColumnExpr::new( Arc::new(column), @@ -663,8 +649,8 @@ impl BatchAdapter { mod tests { use super::*; use arrow::array::{ - BooleanArray, Int32Array, Int64Array, RecordBatch, RecordBatchOptions, - StringArray, StringViewArray, StructArray, + Array, BooleanArray, GenericListArray, Int32Array, Int64Array, RecordBatch, + RecordBatchOptions, StringArray, StringViewArray, StructArray, }; use arrow::datatypes::{DataType, Field, Fields, Schema, SchemaRef}; use datafusion_common::{Result, ScalarValue, assert_contains, record_batch}; @@ -1289,6 +1275,142 @@ mod tests { assert_eq!(extra_values.iter().collect_vec(), vec![None, None, None]); } + /// Test that List columns are properly adapted with struct evolution. + #[test] + fn test_adapt_list_struct_batches() { + // Physical: List<{id: Int32, name: Utf8}> + let physical_struct_fields: Fields = vec![ + Field::new("id", DataType::Int32, false), + Field::new("name", DataType::Utf8, true), + ] + .into(); + + let struct_array = StructArray::new( + physical_struct_fields.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 3])) as _, + Arc::new(StringArray::from(vec![ + Some("alice"), + None, + Some("charlie"), + ])) as _, + ], + None, + ); + + // One list element per row + let item_field = Arc::new(Field::new( + "item", + DataType::Struct(physical_struct_fields.clone()), + true, + )); + let offsets = + arrow::buffer::OffsetBuffer::from_lengths(vec![1usize; struct_array.len()]); + let list_array = GenericListArray::::new( + item_field, + offsets, + Arc::new(struct_array), + None, + ); + + let physical_schema = Arc::new(Schema::new(vec![Field::new( + "data", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(physical_struct_fields), + true, + ))), + false, + )])); + + let physical_batch = RecordBatch::try_new( + Arc::clone(&physical_schema), + vec![Arc::new(list_array)], + ) + .unwrap(); + + // Logical: List<{id: Int64, name: Utf8View, extra: Boolean}> + let logical_struct_fields: Fields = vec![ + Field::new("id", DataType::Int64, false), + Field::new("name", DataType::Utf8View, true), + Field::new("extra", DataType::Boolean, true), + ] + .into(); + + let logical_schema = Arc::new(Schema::new(vec![Field::new( + "data", + DataType::List(Arc::new(Field::new( + "item", + DataType::Struct(logical_struct_fields.clone()), + true, + ))), + false, + )])); + + let projection = vec![col("data", &logical_schema).unwrap()]; + + let factory = DefaultPhysicalExprAdapterFactory; + let adapter = factory + .create(Arc::clone(&logical_schema), Arc::clone(&physical_schema)) + .unwrap(); + + let adapted_projection = projection + .into_iter() + .map(|expr| adapter.rewrite(expr).unwrap()) + .collect_vec(); + + let adapted_schema = Arc::new(Schema::new( + adapted_projection + .iter() + .map(|expr| expr.return_field(&physical_schema).unwrap()) + .collect_vec(), + )); + + let res = batch_project( + adapted_projection, + &physical_batch, + Arc::clone(&adapted_schema), + ) + .unwrap(); + + assert_eq!(res.num_columns(), 1); + + let result_list = res + .column(0) + .as_any() + .downcast_ref::>() + .unwrap(); + + // Check each list element contains the evolved struct + assert_eq!(result_list.len(), 3); + let flat_structs = result_list + .values() + .as_any() + .downcast_ref::() + .unwrap(); + + let id_col = flat_structs.column_by_name("id").unwrap(); + assert_eq!(id_col.data_type(), &DataType::Int64); + let id_values = id_col.as_any().downcast_ref::().unwrap(); + assert_eq!( + id_values.iter().collect_vec(), + vec![Some(1), Some(2), Some(3)] + ); + + let name_col = flat_structs.column_by_name("name").unwrap(); + assert_eq!(name_col.data_type(), &DataType::Utf8View); + let name_values = name_col.as_any().downcast_ref::().unwrap(); + assert_eq!( + name_values.iter().collect_vec(), + vec![Some("alice"), None, Some("charlie")] + ); + + let extra_col = flat_structs.column_by_name("extra").unwrap(); + assert_eq!(extra_col.data_type(), &DataType::Boolean); + let extra_values = extra_col.as_any().downcast_ref::().unwrap(); + assert_eq!(extra_values.iter().collect_vec(), vec![None, None, None]); + } + #[test] fn test_try_rewrite_struct_field_access() { // Test the core logic of try_rewrite_struct_field_access diff --git a/datafusion/physical-expr/src/expressions/cast.rs b/datafusion/physical-expr/src/expressions/cast.rs index 5a80daf663a36..24e486f8050fe 100644 --- a/datafusion/physical-expr/src/expressions/cast.rs +++ b/datafusion/physical-expr/src/expressions/cast.rs @@ -27,7 +27,9 @@ use arrow::datatypes::{DataType, DataType::*, FieldRef, Schema}; use arrow::record_batch::RecordBatch; use datafusion_common::datatype::DataTypeExt; use datafusion_common::format::DEFAULT_FORMAT_OPTIONS; -use datafusion_common::nested_struct::validate_struct_compatibility; +use datafusion_common::nested_struct::{ + requires_nested_struct_cast, validate_data_type_compatibility, +}; use datafusion_common::{Result, not_impl_err}; use datafusion_expr_common::columnar_value::ColumnarValue; use datafusion_expr_common::interval_arithmetic::Interval; @@ -43,20 +45,14 @@ const DEFAULT_SAFE_CAST_OPTIONS: CastOptions<'static> = CastOptions { format_options: DEFAULT_FORMAT_OPTIONS, }; -/// Check if struct-to-struct casting is allowed by validating field compatibility. +/// Check if name-based struct casting is allowed by validating field compatibility. /// /// This function applies the same validation rules as execution time to ensure /// planning-time validation matches runtime validation, enabling fail-fast behavior -/// instead of deferring errors to execution. -fn can_cast_struct_types(source: &DataType, target: &DataType) -> bool { - match (source, target) { - (Struct(source_fields), Struct(target_fields)) => { - // Apply the same struct compatibility rules as at execution time. - // This ensures planning-time validation matches execution-time validation. - validate_struct_compatibility(source_fields, target_fields).is_ok() - } - _ => false, - } +/// instead of deferring errors to execution. Handles structs at any nesting level +/// (e.g., `List`, `Dictionary<_, Struct>`). +fn can_cast_named_struct_types(source: &DataType, target: &DataType) -> bool { + validate_data_type_compatibility("", source, target).is_ok() } /// CAST expression casts an expression to a specific data type and returns a runtime error on invalid cast @@ -323,12 +319,13 @@ pub fn cast_with_options( let expr_type = expr.data_type(input_schema)?; if expr_type == cast_type { Ok(Arc::clone(&expr)) - } else if matches!((&expr_type, &cast_type), (Struct(_), Struct(_))) { - if can_cast_struct_types(&expr_type, &cast_type) { - // Allow struct-to-struct casts that pass name-based compatibility validation. - // This validation is applied at planning time (now) to fail fast, rather than - // deferring errors to execution time. The name-based casting logic will be - // executed at runtime via ColumnarValue::cast_to. + } else if requires_nested_struct_cast(&expr_type, &cast_type) { + if can_cast_named_struct_types(&expr_type, &cast_type) { + // Allow casts involving structs (including nested inside Lists, Dictionaries, + // etc.) that pass name-based compatibility validation. This validation is + // applied at planning time (now) to fail fast, rather than deferring errors + // to execution time. The name-based casting logic will be executed at runtime + // via ColumnarValue::cast_to. Ok(Arc::new(CastExpr::new(expr, cast_type, cast_options))) } else { not_impl_err!("Unsupported CAST from {expr_type} to {cast_type}") diff --git a/datafusion/physical-expr/src/expressions/cast_column.rs b/datafusion/physical-expr/src/expressions/cast_column.rs index a99953abdb5cb..158d467c08b59 100644 --- a/datafusion/physical-expr/src/expressions/cast_column.rs +++ b/datafusion/physical-expr/src/expressions/cast_column.rs @@ -140,15 +140,18 @@ impl PhysicalExpr for CastColumnExpr { let value = self.expr.evaluate(batch)?; match value { ColumnarValue::Array(array) => { - let casted = - cast_column(&array, self.target_field.as_ref(), &self.cast_options)?; + let casted = cast_column( + &array, + self.target_field.data_type(), + &self.cast_options, + )?; Ok(ColumnarValue::Array(casted)) } ColumnarValue::Scalar(scalar) => { let as_array = scalar.to_array_of_size(1)?; let casted = cast_column( &as_array, - self.target_field.as_ref(), + self.target_field.data_type(), &self.cast_options, )?; let result = ScalarValue::try_from_array(casted.as_ref(), 0)?;