Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions datafusion/expr/src/logical_plan/invariants.rs
Original file line number Diff line number Diff line change
Expand Up @@ -222,14 +222,31 @@ pub fn check_subquery_expr(
check_correlations_in_subquery(inner_plan)
} else {
if let Expr::InSubquery(subquery) = expr {
// InSubquery should only return one column
if subquery.subquery.subquery.schema().fields().len() > 1 {
// InSubquery should only return one column UNLESS the left expression is a struct
// (multi-column IN like: (a, b) NOT IN (SELECT x, y FROM ...))
let is_struct = matches!(*subquery.expr, Expr::ScalarFunction(ref func) if func.func.name() == "struct");

let num_subquery_cols = subquery.subquery.subquery.schema().fields().len();

if !is_struct && num_subquery_cols > 1 {
return plan_err!(
"InSubquery should only return one column, but found {}: {}",
subquery.subquery.subquery.schema().fields().len(),
num_subquery_cols,
subquery.subquery.subquery.schema().field_names().join(", ")
);
}

// For struct expressions, validate that the number of fields matches
if is_struct && let Expr::ScalarFunction(ref func) = *subquery.expr {
let num_tuple_cols = func.args.len();
if num_tuple_cols != num_subquery_cols {
return plan_err!(
"The number of columns in the tuple ({}) must match the number of columns in the subquery ({})",
num_tuple_cols,
num_subquery_cols
);
}
}
}
if let Expr::SetComparison(set_comparison) = expr
&& set_comparison.subquery.subquery.schema().fields().len() > 1
Expand Down
54 changes: 37 additions & 17 deletions datafusion/optimizer/src/analyzer/type_coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -482,23 +482,43 @@ impl TreeNodeRewriter for TypeCoercionRewriter<'_> {
Arc::unwrap_or_clone(subquery.subquery),
)?
.data;
let expr_type = expr.get_type(self.schema)?;
let subquery_type = new_plan.schema().field(0).data_type();
let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(
plan_datafusion_err!(
"expr type {expr_type} can't cast to {subquery_type} in InSubquery"
),
)?;
let new_subquery = Subquery {
subquery: Arc::new(new_plan),
outer_ref_columns: subquery.outer_ref_columns,
spans: subquery.spans,
};
Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
Box::new(expr.cast_to(&common_type, self.schema)?),
cast_subquery(new_subquery, &common_type)?,
negated,
))))

// Check if this is a multi-column IN (struct expression)
let is_struct = matches!(*expr, Expr::ScalarFunction(ref func) if func.func.name() == "struct");

if is_struct {
// For multi-column IN, we don't need type coercion at this level
// The decorrelation phase will handle this by creating join conditions
let new_subquery = Subquery {
subquery: Arc::new(new_plan),
outer_ref_columns: subquery.outer_ref_columns,
spans: subquery.spans,
};
Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
expr,
new_subquery,
negated,
))))
} else {
// Single-column IN: apply type coercion as before
let expr_type = expr.get_type(self.schema)?;
let subquery_type = new_plan.schema().field(0).data_type();
let common_type = comparison_coercion(&expr_type, subquery_type).ok_or(
plan_datafusion_err!(
"expr type {expr_type} can't cast to {subquery_type} in InSubquery"
),
)?;
let new_subquery = Subquery {
subquery: Arc::new(new_plan),
outer_ref_columns: subquery.outer_ref_columns,
spans: subquery.spans,
};
Ok(Transformed::yes(Expr::InSubquery(InSubquery::new(
Box::new(expr.cast_to(&common_type, self.schema)?),
cast_subquery(new_subquery, &common_type)?,
negated,
))))
}
}
Expr::SetComparison(SetComparison {
expr,
Expand Down
100 changes: 95 additions & 5 deletions datafusion/optimizer/src/decorrelate_predicate_subquery.rs
Original file line number Diff line number Diff line change
Expand Up @@ -386,9 +386,57 @@ fn build_join(
right,
})),
) => {
let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
let in_predicate = Expr::eq(left.deref().clone(), Expr::Column(right_col));
in_predicate.and(join_filter)
// Check if this is a multi-column IN (struct expression)
if let Expr::ScalarFunction(func) = left.deref() {
if func.func.name() == "struct" {
// Decompose struct into individual field comparisons
let struct_args = &func.args;

// The right side should be the subquery result
// Note: After pull-up, the subquery may have additional correlated columns
// We only care about the first N columns that match our struct fields
let subquery_fields = sub_query_alias.schema().fields();

if struct_args.len() > subquery_fields.len() {
return plan_err!(
"Struct field count ({}) exceeds subquery column count ({})",
struct_args.len(),
subquery_fields.len()
);
}

// Create equality conditions for each field
let mut conditions = Vec::new();
for (i, arg) in struct_args.iter().enumerate() {
let field = &subquery_fields[i];
let right_col = Expr::Column(Column::new(
Some(alias.clone()),
field.name().to_string(),
));
conditions.push(Expr::eq(arg.clone(), right_col));
}

// Combine all conditions with AND
let in_predicate = conditions
.into_iter()
.reduce(|acc, cond| acc.and(cond))
.unwrap_or_else(|| lit(true));

in_predicate.and(join_filter)
} else {
// Regular scalar function, handle as before
let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
let in_predicate =
Expr::eq(left.deref().clone(), Expr::Column(right_col));
in_predicate.and(join_filter)
}
} else {
// Not a struct, handle as before
let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
let in_predicate =
Expr::eq(left.deref().clone(), Expr::Column(right_col));
in_predicate.and(join_filter)
}
}
(Some(join_filter), _) => join_filter,
(
Expand All @@ -399,9 +447,51 @@ fn build_join(
right,
})),
) => {
let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
// Check if this is a multi-column IN (struct expression)
if let Expr::ScalarFunction(func) = left.deref() {
if func.func.name() == "struct" {
// Decompose struct into individual field comparisons
let struct_args = &func.args;

// The right side should be the subquery result
// Note: After pull-up, the subquery may have additional correlated columns
// We only care about the first N columns that match our struct fields
let subquery_fields = sub_query_alias.schema().fields();

if struct_args.len() > subquery_fields.len() {
return plan_err!(
"Struct field count ({}) exceeds subquery column count ({})",
struct_args.len(),
subquery_fields.len()
);
}

// Create equality conditions for each field
let mut conditions = Vec::new();
for (i, arg) in struct_args.iter().enumerate() {
let field = &subquery_fields[i];
let right_col = Expr::Column(Column::new(
Some(alias.clone()),
field.name().to_string(),
));
conditions.push(Expr::eq(arg.clone(), right_col));
}

Expr::eq(left.deref().clone(), Expr::Column(right_col))
// Combine all conditions with AND
conditions
.into_iter()
.reduce(|acc, cond| acc.and(cond))
.unwrap_or_else(|| lit(true))
} else {
// Regular scalar function, handle as before
let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
Expr::eq(left.deref().clone(), Expr::Column(right_col))
}
} else {
// Not a struct, handle as before
let right_col = create_col_from_scalar_expr(right.deref(), alias)?;
Expr::eq(left.deref().clone(), Expr::Column(right_col))
}
}
(None, None) => lit(true),
_ => return Ok(None),
Expand Down
Loading