diff --git a/datafusion/expr/src/logical_plan/invariants.rs b/datafusion/expr/src/logical_plan/invariants.rs
index b39b23e30f4e8..8f5f5c0454e20 100644
--- a/datafusion/expr/src/logical_plan/invariants.rs
+++ b/datafusion/expr/src/logical_plan/invariants.rs
@@ -222,14 +222,31 @@ pub fn check_subquery_expr(
check_correlations_in_subquery(inner_plan)
} else {
if let Expr::InSubquery(subquery) = expr {
- // InSubquery should only return one column
- if subquery.subquery.subquery.schema().fields().len() > 1 {
+ // InSubquery should only return one column UNLESS the left expression is a struct
+ // (multi-column IN like: (a, b) NOT IN (SELECT x, y FROM ...))
+ let is_struct = matches!(*subquery.expr, Expr::ScalarFunction(ref func) if func.func.name() == "struct");
+
+ let num_subquery_cols = subquery.subquery.subquery.schema().fields().len();
+
+ if !is_struct && num_subquery_cols > 1 {
return plan_err!(
"InSubquery should only return one column, but found {}: {}",
- subquery.subquery.subquery.schema().fields().len(),
+ num_subquery_cols,
subquery.subquery.subquery.schema().field_names().join(", ")
);
}
+
+ // For struct expressions, validate that the number of fields matches
+ if is_struct && let Expr::ScalarFunction(ref func) = *subquery.expr {
+ let num_tuple_cols = func.args.len();
+ if num_tuple_cols != num_subquery_cols {
+ return plan_err!(
+ "The number of columns in the tuple ({}) must match the number of columns in the subquery ({})",
+ num_tuple_cols,
+ num_subquery_cols
+ );
+ }
+ }
}
if let Expr::SetComparison(set_comparison) = expr
&& set_comparison.subquery.subquery.schema().fields().len() > 1
diff --git a/datafusion/optimizer/src/analyzer/type_coercion.rs b/datafusion/optimizer/src/analyzer/type_coercion.rs
index a98678f7cf9c4..7a54404cab272 100644
--- a/datafusion/optimizer/src/analyzer/type_coercion.rs
+++ b/datafusion/optimizer/src/analyzer/type_coercion.rs
@@ -482,23 +482,43 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> {
Arc::unwrap_or_clone(subquery.subquery),
)?
.data;
- let expr_type = expr.get_type(self.schema)?;
- let subquery_type = new_plan.schema().field(0).data_type();
- let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(
- plan_datafusion_err!(
- "expr type {expr_type} can't cast to {subquery_type} in InSubquery"
- ),
- )?;
- let new_subquery = Subquery {
- subquery: Arc::new(new_plan),
- outer_ref_columns: subquery.outer_ref_columns,
- spans: subquery.spans,
- };
- Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
- Box::new(expr.cast_to(&common_type, self.schema)?),
- cast_subquery(new_subquery, &common_type)?,
- negated,
- ))))
+
+ // Check if this is a multi-column IN (struct expression)
+ let is_struct = matches!(*expr, Expr::ScalarFunction(ref func) if func.func.name() == "struct");
+
+ if is_struct {
+ // For multi-column IN, we don't need type coercion at this level
+ // The decorrelation phase will handle this by creating join conditions
+ let new_subquery = Subquery {
+ subquery: Arc::new(new_plan),
+ outer_ref_columns: subquery.outer_ref_columns,
+ spans: subquery.spans,
+ };
+ Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
+ expr,
+ new_subquery,
+ negated,
+ ))))
+ } else {
+ // Single-column IN: apply type coercion as before
+ let expr_type = expr.get_type(self.schema)?;
+ let subquery_type = new_plan.schema().field(0).data_type();
+ let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(
+ plan_datafusion_err!(
+ "expr type {expr_type} can't cast to {subquery_type} in InSubquery"
+ ),
+ )?;
+ let new_subquery = Subquery {
+ subquery: Arc::new(new_plan),
+ outer_ref_columns: subquery.outer_ref_columns,
+ spans: subquery.spans,
+ };
+ Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
+ Box::new(expr.cast_to(&common_type, self.schema)?),
+ cast_subquery(new_subquery, &common_type)?,
+ negated,
+ ))))
+ }
}
Expr::SetComparison(SetComparison {
expr,
diff --git a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs
index b9d160d55589f..552f8b6aeceb1 100644
--- a/datafusion/optimizer/src/decorrelate_predicate_subquery.rs
+++ b/datafusion/optimizer/src/decorrelate_predicate_subquery.rs
@@ -386,9 +386,57 @@ fn build_join(
right,
})),
) => {
- let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
- let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col));
- in_predicate.and(join_filter)
+ // Check if this is a multi-column IN (struct expression)
+ if let Expr::ScalarFunction(func) = left.deref() {
+ if func.func.name() == "struct" {
+ // Decompose struct into individual field comparisons
+ let struct_args = &func.args;
+
+ // The right side should be the subquery result
+ // Note: After pull-up, the subquery may have additional correlated columns
+ // We only care about the first N columns that match our struct fields
+ let subquery_fields = sub_query_alias.schema().fields();
+
+ if struct_args.len() > subquery_fields.len() {
+ return plan_err!(
+ "Struct field count ({}) exceeds subquery column count ({})",
+ struct_args.len(),
+ subquery_fields.len()
+ );
+ }
+
+ // Create equality conditions for each field
+ let mut conditions = Vec::new();
+ for (i, arg) in struct_args.iter().enumerate() {
+ let field = &subquery_fields[i];
+ let right_col = Expr::Column(Column::new(
+ Some(alias.clone()),
+ field.name().to_string(),
+ ));
+ conditions.push(Expr::eq(arg.clone(), right_col));
+ }
+
+ // Combine all conditions with AND
+ let in_predicate = conditions
+ .into_iter()
+ .reduce(|acc, cond| acc.and(cond))
+ .unwrap_or_else(|| lit(true));
+
+ in_predicate.and(join_filter)
+ } else {
+ // Regular scalar function, handle as before
+ let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
+ let in_predicate =
+ Expr::eq(left.deref().clone(), Expr::Column(right_col));
+ in_predicate.and(join_filter)
+ }
+ } else {
+ // Not a struct, handle as before
+ let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
+ let in_predicate =
+ Expr::eq(left.deref().clone(), Expr::Column(right_col));
+ in_predicate.and(join_filter)
+ }
}
(Some(join_filter), _) => join_filter,
(
@@ -399,9 +447,51 @@ fn build_join(
right,
})),
) => {
- let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
+ // Check if this is a multi-column IN (struct expression)
+ if let Expr::ScalarFunction(func) = left.deref() {
+ if func.func.name() == "struct" {
+ // Decompose struct into individual field comparisons
+ let struct_args = &func.args;
+
+ // The right side should be the subquery result
+ // Note: After pull-up, the subquery may have additional correlated columns
+ // We only care about the first N columns that match our struct fields
+ let subquery_fields = sub_query_alias.schema().fields();
+
+ if struct_args.len() > subquery_fields.len() {
+ return plan_err!(
+ "Struct field count ({}) exceeds subquery column count ({})",
+ struct_args.len(),
+ subquery_fields.len()
+ );
+ }
+
+ // Create equality conditions for each field
+ let mut conditions = Vec::new();
+ for (i, arg) in struct_args.iter().enumerate() {
+ let field = &subquery_fields[i];
+ let right_col = Expr::Column(Column::new(
+ Some(alias.clone()),
+ field.name().to_string(),
+ ));
+ conditions.push(Expr::eq(arg.clone(), right_col));
+ }
- Expr::eq(left.deref().clone(), Expr::Column(right_col))
+ // Combine all conditions with AND
+ conditions
+ .into_iter()
+ .reduce(|acc, cond| acc.and(cond))
+ .unwrap_or_else(|| lit(true))
+ } else {
+ // Regular scalar function, handle as before
+ let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
+ Expr::eq(left.deref().clone(), Expr::Column(right_col))
+ }
+ } else {
+ // Not a struct, handle as before
+ let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
+ Expr::eq(left.deref().clone(), Expr::Column(right_col))
+ }
}
(None, None) => lit(true),
_ => return Ok(None),
diff --git a/datafusion/physical-plan/src/joins/hash_join/exec.rs b/datafusion/physical-plan/src/joins/hash_join/exec.rs
index a330ad54cb338..9666896984982 100644
--- a/datafusion/physical-plan/src/joins/hash_join/exec.rs
+++ b/datafusion/physical-plan/src/joins/hash_join/exec.rs
@@ -339,18 +339,10 @@ impl HashJoinExecBuilder {
check_join_is_valid(&left_schema, &right_schema, &on)?;
// Validate null_aware flag
- if null_aware {
- if !matches!(join_type, JoinType::LeftAnti) {
- return plan_err!(
- "null_aware can only be true for LeftAnti joins, got {join_type}"
- );
- }
- if on.len() != 1 {
- return plan_err!(
- "null_aware anti join only supports single column join key, got {} columns",
- on.len()
- );
- }
+ if null_aware && !matches!(join_type, JoinType::LeftAnti) {
+ return plan_err!(
+ "null_aware can only be true for LeftAnti joins, got {join_type}"
+ );
}
let (join_schema, column_indices) =
@@ -5659,26 +5651,41 @@ mod tests {
);
}
- /// Test that null_aware validation rejects multi-column joins
+ /// Test null-aware anti join with multi-column (a, b) NOT IN (SELECT x, y FROM ...)
+ /// When probe side (right) has NULL in ANY column, result should be empty
+ #[apply(hash_join_exec_configs)]
#[tokio::test]
- async fn test_null_aware_validation_multi_column() {
- let left = build_table(("a", &vec![1]), ("b", &vec![2]), ("c", &vec![3]));
- let right = build_table(("x", &vec![1]), ("y", &vec![2]), ("z", &vec![3]));
+ async fn test_null_aware_anti_join_multi_column_probe_null(
+ batch_size: usize,
+ ) -> Result<()> {
+ let task_ctx = prepare_task_ctx(batch_size, false);
+
+ // Build left table (rows to potentially output)
+ let left = build_table_three_cols(
+ ("a", &vec![Some(1), Some(3), Some(5)]),
+ ("b", &vec![Some(2), Some(4), Some(6)]),
+ ("dummy", &vec![Some(10), Some(30), Some(50)]),
+ );
+
+ // Build right table (has NULL in second column)
+ let right = build_table_two_cols(
+ ("x", &vec![Some(1), Some(7)]),
+ ("y", &vec![Some(2), None]), // NULL in y column
+ );
- // Try multi-column join
let on = vec![
(
- Arc::new(Column::new_with_schema("a", &left.schema()).unwrap()) as _,
- Arc::new(Column::new_with_schema("x", &right.schema()).unwrap()) as _,
+ Arc::new(Column::new_with_schema("a", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("x", &right.schema())?) as _,
),
(
- Arc::new(Column::new_with_schema("b", &left.schema()).unwrap()) as _,
- Arc::new(Column::new_with_schema("y", &right.schema()).unwrap()) as _,
+ Arc::new(Column::new_with_schema("b", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("y", &right.schema())?) as _,
),
];
- // Try to create null-aware anti join with 2 columns (should fail)
- let result = HashJoinExec::try_new(
+ // Create null-aware anti join
+ let join = HashJoinExec::try_new(
left,
right,
on,
@@ -5687,15 +5694,165 @@ mod tests {
None,
PartitionMode::CollectLeft,
NullEquality::NullEqualsNothing,
- true, // null_aware = true (invalid for multi-column)
+ true, // null_aware = true
+ )?;
+
+ let stream = join.execute(0, task_ctx)?;
+ let batches = common::collect(stream).await?;
+
+ // Expected: empty result (probe side has NULL in y column, so no rows should be output)
+ allow_duplicates! {
+ assert_snapshot!(batches_to_sort_string(&batches), @r"
+ ++
+ ++
+ ");
+ }
+ Ok(())
+ }
+
+ /// Test null-aware anti join with multi-column when probe side has no NULLs
+ /// Expected: rows that don't match should be output, but rows with NULL keys should be filtered
+ #[apply(hash_join_exec_configs)]
+ #[tokio::test]
+ async fn test_null_aware_anti_join_multi_column_no_null(
+ batch_size: usize,
+ ) -> Result<()> {
+ let task_ctx = prepare_task_ctx(batch_size, false);
+
+ // Build left table with some NULL keys
+ let left = build_table_three_cols(
+ ("a", &vec![Some(1), Some(3), Some(5), None]),
+ ("b", &vec![Some(2), Some(4), Some(6), Some(8)]),
+ ("dummy", &vec![Some(10), Some(30), Some(50), Some(0)]),
);
- assert!(result.is_err());
- assert!(
- result
- .unwrap_err()
- .to_string()
- .contains("null_aware anti join only supports single column join key")
+ // Build right table (no NULLs)
+ let right = build_table_two_cols(("x", &vec![Some(1)]), ("y", &vec![Some(2)]));
+
+ let on = vec![
+ (
+ Arc::new(Column::new_with_schema("a", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("x", &right.schema())?) as _,
+ ),
+ (
+ Arc::new(Column::new_with_schema("b", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("y", &right.schema())?) as _,
+ ),
+ ];
+
+ // Create null-aware anti join
+ let join = HashJoinExec::try_new(
+ left,
+ right,
+ on,
+ None,
+ &JoinType::LeftAnti,
+ None,
+ PartitionMode::CollectLeft,
+ NullEquality::NullEqualsNothing,
+ true, // null_aware = true
+ )?;
+
+ let stream = join.execute(0, task_ctx)?;
+ let batches = common::collect(stream).await?;
+
+ // Expected: (3, 4, 30) and (5, 6, 50)
+ // Row (1, 2, 10) matches right side, so filtered out
+ // Row (NULL, 8, 0) has NULL in a column, so filtered out
+ allow_duplicates! {
+ assert_snapshot!(batches_to_sort_string(&batches), @r"
+ +---+---+-------+
+ | a | b | dummy |
+ +---+---+-------+
+ | 3 | 4 | 30 |
+ | 5 | 6 | 50 |
+ +---+---+-------+
+ ");
+ }
+ Ok(())
+ }
+
+ /// Test null-aware anti join with three columns (a, b, c) NOT IN (SELECT x, y, z FROM ...)
+ #[apply(hash_join_exec_configs)]
+ #[tokio::test]
+ async fn test_null_aware_anti_join_three_columns(batch_size: usize) -> Result<()> {
+ let task_ctx = prepare_task_ctx(batch_size, false);
+
+ // Build left table
+ let left = build_table_three_cols(
+ ("a", &vec![Some(1), Some(4), Some(7)]),
+ ("b", &vec![Some(2), Some(5), Some(8)]),
+ ("c", &vec![Some(3), Some(6), Some(9)]),
+ );
+
+ // Build right table with NULL in third column
+ let right = build_table_three_cols(
+ ("x", &vec![Some(1)]),
+ ("y", &vec![Some(2)]),
+ ("z", &vec![None]), // NULL in z column
);
+
+ let on = vec![
+ (
+ Arc::new(Column::new_with_schema("a", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("x", &right.schema())?) as _,
+ ),
+ (
+ Arc::new(Column::new_with_schema("b", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("y", &right.schema())?) as _,
+ ),
+ (
+ Arc::new(Column::new_with_schema("c", &left.schema())?) as _,
+ Arc::new(Column::new_with_schema("z", &right.schema())?) as _,
+ ),
+ ];
+
+ // Create null-aware anti join
+ let join = HashJoinExec::try_new(
+ left,
+ right,
+ on,
+ None,
+ &JoinType::LeftAnti,
+ None,
+ PartitionMode::CollectLeft,
+ NullEquality::NullEqualsNothing,
+ true, // null_aware = true
+ )?;
+
+ let stream = join.execute(0, task_ctx)?;
+ let batches = common::collect(stream).await?;
+
+ // Expected: empty result (probe side has NULL in z column)
+ allow_duplicates! {
+ assert_snapshot!(batches_to_sort_string(&batches), @r"
+ ++
+ ++
+ ");
+ }
+ Ok(())
+ }
+
+ /// Helper to build a table with three columns supporting nullable values
+ fn build_table_three_cols(
+ a: (&str, &Vec