diff --git a/datafusion/core/tests/physical_optimizer/partition_statistics.rs b/datafusion/core/tests/physical_optimizer/partition_statistics.rs
index 12ce141b4759a..6abe9a5f6b0cd 100644
--- a/datafusion/core/tests/physical_optimizer/partition_statistics.rs
+++ b/datafusion/core/tests/physical_optimizer/partition_statistics.rs
@@ -933,7 +933,10 @@ mod test {
             num_rows: Precision::Exact(0),
             total_byte_size: Precision::Absent,
             column_statistics: vec![
-                ColumnStatistics::new_unknown(),
+                ColumnStatistics {
+                    distinct_count: Precision::Exact(0),
+                    ..ColumnStatistics::new_unknown()
+                },
                 ColumnStatistics::new_unknown(),
                 ColumnStatistics::new_unknown(),
             ],
diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs
index e573a7f85ddf9..6fa78a1981638 100644
--- a/datafusion/physical-plan/src/aggregates/mod.rs
+++ b/datafusion/physical-plan/src/aggregates/mod.rs
@@ -1037,6 +1037,23 @@ impl AggregateExec {
         &self.input_order_mode
     }
 
+    /// Estimates output statistics for this aggregate node.
+    ///
+    /// For grouped aggregations with known input row count > 1, the output row
+    /// count is estimated as:
+    ///
+    /// ```text
+    /// output_rows = input_rows                                        // baseline
+    /// output_rows = min(output_rows, product(NDV_i + nulls_i) * sets) // if NDV available
+    /// output_rows = min(output_rows, limit)                           // if TopK active
+    /// ```
+    ///
+    /// The NDV estimation is heavily inspired by Spark and Trino.
+    /// - Multiplies distinct value counts of all group-by columns
+    /// - Adds +1 per column when nulls are present (a null group is a distinct output row)
+    /// - Caps the result by the input row count
+    /// - Requires NDV stats for ALL group-by columns; if any column lacks stats,
+    ///   falls back to `input_rows` as the upper bound
     fn statistics_inner(&self, child_statistics: &Statistics) -> Result<Statistics> {
         // TODO stats: group expressions:
         // - once expressions will be able to compute their own stats, use it here
@@ -1050,15 +1067,12 @@ impl AggregateExec {
 
             for (idx, (expr, _)) in self.group_by.expr.iter().enumerate() {
                 if let Some(col) = expr.as_any().downcast_ref::<Column>() {
-                    column_statistics[idx].max_value = child_statistics.column_statistics
-                        [col.index()]
-                    .max_value
-                    .clone();
-
-                    column_statistics[idx].min_value = child_statistics.column_statistics
-                        [col.index()]
-                    .min_value
-                    .clone();
+                    let child_col_stats =
+                        &child_statistics.column_statistics[col.index()];
+                    column_statistics[idx].max_value = child_col_stats.max_value.clone();
+                    column_statistics[idx].min_value = child_col_stats.min_value.clone();
+                    column_statistics[idx].distinct_count =
+                        child_col_stats.distinct_count;
                 }
             }
 
@@ -1083,7 +1097,23 @@ impl AggregateExec {
 
         let num_rows = if let Some(value) = child_statistics.num_rows.get_value() {
             if *value > 1 {
-                child_statistics.num_rows.to_inexact()
+                let mut num_rows = child_statistics.num_rows.to_inexact();
+
+                if !self.group_by.expr.is_empty() {
+                    let ndv_product = self.compute_group_ndv(child_statistics);
+                    if let Some(ndv) = ndv_product {
+                        let grouping_set_num = self.group_by.groups.len();
+                        let ndv_estimate = ndv.saturating_mul(grouping_set_num);
+                        num_rows = num_rows.map(|n| n.min(ndv_estimate));
+                    }
+                }
+
+                // If TopK mode is active, cap output rows by the limit
+                if let Some(limit_opts) = &self.limit_options {
+                    num_rows = num_rows.map(|n| n.min(limit_opts.limit));
+                }
+
+                num_rows
             } else if *value == 0 {
                 child_statistics.num_rows
             } else {
@@ -1091,6 +1121,8 @@ impl AggregateExec {
                 let grouping_set_num = self.group_by.groups.len();
                 child_statistics.num_rows.map(|x| x * grouping_set_num)
             }
+        } else if let Some(limit_opts) = &self.limit_options {
+            Precision::Inexact(limit_opts.limit)
         } else {
             Precision::Absent
         };
@@ -1113,6 +1145,24 @@ impl AggregateExec {
         }
     }
 
+    /// Computes `product(NDV_i + null_adjustment_i)` across all group-by columns.
+    /// Returns `None` if any group-by column is not a direct column reference
+    /// or lacks `distinct_count` stats.
+    fn compute_group_ndv(&self, child_statistics: &Statistics) -> Option<usize> {
+        let mut product: usize = 1;
+        for (expr, _) in self.group_by.expr.iter() {
+            let col = expr.as_any().downcast_ref::<Column>()?;
+            let col_stats = &child_statistics.column_statistics[col.index()];
+            let ndv = *col_stats.distinct_count.get_value()?;
+            let null_adjustment = match col_stats.null_count.get_value() {
+                Some(&n) if n > 0 => 1usize,
+                _ => 0,
+            };
+            product = product.saturating_mul(ndv.saturating_add(null_adjustment));
+        }
+        Some(product)
+    }
+
     /// Check if dynamic filter is possible for the current plan node.
     /// - If yes, init one inside `AggregateExec`'s `dynamic_filter` field.
     /// - If not supported, `self.dynamic_filter` should be kept `None`
@@ -3795,6 +3845,494 @@ mod tests {
         Ok(())
     }
 
+    #[test]
+    fn test_aggregate_cardinality_estimation() -> Result<()> {
+        use crate::test::exec::StatisticsExec;
+        use datafusion_common::ColumnStatistics;
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Int32, true),
+        ]));
+
+        struct TestCase {
+            name: &'static str,
+            input_rows: Precision<usize>,
+            col_a_stats: ColumnStatistics,
+            col_b_stats: ColumnStatistics,
+            group_by_cols: Vec<&'static str>,
+            limit_options: Option<LimitOptions>,
+            expected_num_rows: Precision<usize>,
+        }
+
+        let cases = vec![
+            // --- NDV-based estimation ---
+            TestCase {
+                name: "single group-by col with NDV tightens estimate",
+                input_rows: Precision::Exact(1_000_000),
+                col_a_stats: ColumnStatistics {
+                    distinct_count: Precision::Exact(500),
+                    ..ColumnStatistics::new_unknown()
+                },
+                col_b_stats: ColumnStatistics::new_unknown(),
+                group_by_cols: vec!["a"],
+                limit_options: None,
+                expected_num_rows: Precision::Inexact(500),
+            },
+            TestCase {
+                name: "multi-col group-by multiplies NDVs",
+                input_rows: Precision::Exact(1_000_000),
+                col_a_stats: ColumnStatistics {
+                    distinct_count: Precision::Exact(100),
+                    ..ColumnStatistics::new_unknown()
+                },
+                col_b_stats: ColumnStatistics {
+                    distinct_count: Precision::Exact(50),
+                    ..ColumnStatistics::new_unknown()
+                },
+                group_by_cols: vec!["a", "b"],
+                limit_options: None,
+                expected_num_rows: Precision::Inexact(5_000),
+            },
+            TestCase {
+                name: "NDV product capped by input rows",
+                input_rows: Precision::Exact(200),
+                col_a_stats: ColumnStatistics {
+                    distinct_count: Precision::Exact(100),
+                    ..ColumnStatistics::new_unknown()
+                },
+                col_b_stats: ColumnStatistics {
+                    distinct_count: Precision::Exact(50),
+                    ..ColumnStatistics::new_unknown()
+                },
+                group_by_cols: vec!["a", "b"],
+                limit_options: None,
+                expected_num_rows: Precision::Inexact(200),
+            },
+            TestCase {
+                name: "null adjustment adds +1 per column",
+                input_rows: Precision::Exact(1_000_000),
+                col_a_stats: ColumnStatistics {
+                    distinct_count: Precision::Exact(99),
+                    null_count: Precision::Exact(10),
+                    ..ColumnStatistics::new_unknown()
+                },
+                col_b_stats: ColumnStatistics::new_unknown(),
+                group_by_cols: vec!["a"],
+                limit_options: None,
+                // 99 + 1 (null adjustment) = 100
+                expected_num_rows: Precision::Inexact(100),
+            },
+            TestCase {
+                name: "null adjustment on multiple columns",
+                input_rows: Precision::Exact(1_000_000),
+                col_a_stats: ColumnStatistics {
+                    distinct_count: Precision::Exact(99),
+                    null_count: Precision::Exact(5),
+                    ..ColumnStatistics::new_unknown()
+                },
+                col_b_stats: ColumnStatistics {
+                    distinct_count: Precision::Exact(49),
+                    null_count: Precision::Exact(3),
+                    ..ColumnStatistics::new_unknown()
+                },
+                group_by_cols: vec!["a", "b"],
+                limit_options: None,
+                // (99+1) * (49+1) = 100 * 50 = 5000
+                expected_num_rows: Precision::Inexact(5_000),
+            },
+            TestCase {
+                name: "zero null_count means no adjustment",
+                input_rows: Precision::Exact(1_000_000),
+                col_a_stats: ColumnStatistics {
+                    distinct_count: Precision::Exact(100),
+                    null_count: Precision::Exact(0),
+                    ..ColumnStatistics::new_unknown()
+                },
+                col_b_stats: ColumnStatistics::new_unknown(),
+                group_by_cols: vec!["a"],
+                limit_options: None,
+                expected_num_rows: Precision::Inexact(100),
+            },
+            // --- Bail-out: partial NDV stats (Spark-style) ---
+            TestCase {
+                name: "bail out when one group-by col lacks NDV",
+                input_rows: Precision::Exact(1_000_000),
+                col_a_stats: ColumnStatistics {
+                    distinct_count: Precision::Exact(100),
+                    ..ColumnStatistics::new_unknown()
+                },
+                col_b_stats: ColumnStatistics::new_unknown(),
+                group_by_cols: vec!["a", "b"],
+                limit_options: None,
+                expected_num_rows: Precision::Inexact(1_000_000),
+            },
+            TestCase {
+                name: "bail out when all group-by cols lack NDV",
+                input_rows: Precision::Exact(1_000_000),
+                col_a_stats: ColumnStatistics::new_unknown(),
+                col_b_stats: ColumnStatistics::new_unknown(),
+                group_by_cols: vec!["a"],
+                limit_options: None,
+                expected_num_rows: Precision::Inexact(1_000_000),
+            },
+            // --- TopK limit capping ---
+            TestCase {
+                name: "TopK limit caps output rows",
+                input_rows: Precision::Exact(1_000_000),
+                col_a_stats: ColumnStatistics::new_unknown(),
+                col_b_stats: ColumnStatistics::new_unknown(),
+                group_by_cols: vec!["a"],
+                limit_options: Some(LimitOptions::new(10)),
+                expected_num_rows: Precision::Inexact(10),
+            },
+            TestCase {
+                name: "NDV + TopK limit: min(NDV, limit) when NDV < limit",
+                input_rows: Precision::Exact(1_000_000),
+                col_a_stats: ColumnStatistics {
+                    distinct_count: Precision::Exact(5),
+                    ..ColumnStatistics::new_unknown()
+                },
+                col_b_stats: ColumnStatistics::new_unknown(),
+                group_by_cols: vec!["a"],
+                limit_options: Some(LimitOptions::new(10)),
+                expected_num_rows: Precision::Inexact(5),
+            },
+            TestCase {
+                name: "NDV + TopK limit: min(NDV, limit) when limit < NDV",
+                input_rows: Precision::Exact(1_000_000),
+                col_a_stats: ColumnStatistics {
+                    distinct_count: Precision::Exact(500),
+                    ..ColumnStatistics::new_unknown()
+                },
+                col_b_stats: ColumnStatistics::new_unknown(),
+                group_by_cols: vec!["a"],
+                limit_options: Some(LimitOptions::new(10)),
+                expected_num_rows: Precision::Inexact(10),
+            },
+            // --- Absent input rows ---
+            TestCase {
+                name: "absent input rows without limit stays absent",
+                input_rows: Precision::Absent,
+                col_a_stats: ColumnStatistics::new_unknown(),
+                col_b_stats: ColumnStatistics::new_unknown(),
+                group_by_cols: vec!["a"],
+                limit_options: None,
+                expected_num_rows: Precision::Absent,
+            },
+            TestCase {
+                name: "absent input rows with TopK limit gives inexact(limit)",
+                input_rows: Precision::Absent,
+                col_a_stats: ColumnStatistics::new_unknown(),
+                col_b_stats: ColumnStatistics::new_unknown(),
+                group_by_cols: vec!["a"],
+                limit_options: Some(LimitOptions::new(10)),
+                expected_num_rows: Precision::Inexact(10),
+            },
+            // --- No group-by (global aggregation) ---
+            TestCase {
+                name: "no group-by cols (Final mode) returns Exact(1)",
+                input_rows: Precision::Exact(1_000_000),
+                col_a_stats: ColumnStatistics::new_unknown(),
+                col_b_stats: ColumnStatistics::new_unknown(),
+                group_by_cols: vec![],
+                limit_options: None,
+                expected_num_rows: Precision::Exact(1),
+            },
+            // --- One input row ---
+            TestCase {
+                name: "one input row returns Exact(1)",
+                input_rows: Precision::Exact(1),
+                col_a_stats: ColumnStatistics {
+                    distinct_count: Precision::Exact(1),
+                    ..ColumnStatistics::new_unknown()
+                },
+                col_b_stats: ColumnStatistics::new_unknown(),
+                group_by_cols: vec!["a"],
+                limit_options: None,
+                expected_num_rows: Precision::Exact(1),
+            },
+            // --- Zero input rows ---
+            TestCase {
+                name: "zero input rows returns Exact(0)",
+                input_rows: Precision::Exact(0),
+                col_a_stats: ColumnStatistics::new_unknown(),
+                col_b_stats: ColumnStatistics::new_unknown(),
+                group_by_cols: vec!["a"],
+                limit_options: None,
+                expected_num_rows: Precision::Exact(0),
+            },
+            // --- Inexact NDV stats ---
+            TestCase {
+                name: "inexact NDV still used for estimation",
+                input_rows: Precision::Exact(1_000_000),
+                col_a_stats: ColumnStatistics {
+                    distinct_count: Precision::Inexact(200),
+                    ..ColumnStatistics::new_unknown()
+                },
+                col_b_stats: ColumnStatistics::new_unknown(),
+                group_by_cols: vec!["a"],
+                limit_options: None,
+                expected_num_rows: Precision::Inexact(200),
+            },
+            TestCase {
+                name: "inexact NDV combined with limit",
+                input_rows: Precision::Exact(1_000_000),
+                col_a_stats: ColumnStatistics {
+                    distinct_count: Precision::Inexact(200),
+                    ..ColumnStatistics::new_unknown()
+                },
+                col_b_stats: ColumnStatistics::new_unknown(),
+                group_by_cols: vec!["a"],
+                limit_options: Some(LimitOptions::new(10)),
+                expected_num_rows: Precision::Inexact(10),
+            },
+        ];
+
+        for case in cases {
+            let input_stats = Statistics {
+                num_rows: case.input_rows,
+                total_byte_size: Precision::Inexact(1_000_000),
+                column_statistics: vec![
+                    case.col_a_stats.clone(),
+                    case.col_b_stats.clone(),
+                ],
+            };
+
+            let input = Arc::new(StatisticsExec::new(input_stats, (*schema).clone()))
+                as Arc<dyn ExecutionPlan>;
+
+            let group_by = if case.group_by_cols.is_empty() {
+                PhysicalGroupBy::default()
+            } else {
+                PhysicalGroupBy::new_single(
+                    case.group_by_cols
+                        .iter()
+                        .map(|name| {
+                            (
+                                col(name, &schema).unwrap() as Arc<dyn PhysicalExpr>,
+                                name.to_string(),
+                            )
+                        })
+                        .collect(),
+                )
+            };
+
+            let mut agg = AggregateExec::try_new(
+                AggregateMode::Final,
+                group_by,
+                vec![Arc::new(
+                    AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?])
+                        .schema(Arc::clone(&schema))
+                        .alias("COUNT(a)")
+                        .build()?,
+                )],
+                vec![None],
+                input,
+                Arc::clone(&schema),
+            )?;
+
+            if let Some(limit) = case.limit_options {
+                agg = agg.with_limit_options(Some(limit));
+            }
+
+            let stats = agg.partition_statistics(None)?;
+            assert_eq!(
+                stats.num_rows, case.expected_num_rows,
+                "FAILED: '{}' — expected {:?}, got {:?}",
+                case.name, case.expected_num_rows, stats.num_rows
+            );
+        }
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_aggregate_stats_distinct_count_propagation() -> Result<()> {
+        use crate::test::exec::StatisticsExec;
+        use datafusion_common::ColumnStatistics;
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Int32, true),
+        ]));
+
+        let input = Arc::new(StatisticsExec::new(
+            Statistics {
+                num_rows: Precision::Exact(1000),
+                total_byte_size: Precision::Inexact(10000),
+                column_statistics: vec![
+                    ColumnStatistics {
+                        distinct_count: Precision::Exact(100),
+                        null_count: Precision::Exact(5),
+                        ..ColumnStatistics::new_unknown()
+                    },
+                    ColumnStatistics::new_unknown(),
+                ],
+            },
+            (*schema).clone(),
+        )) as Arc<dyn ExecutionPlan>;
+
+        let agg = AggregateExec::try_new(
+            AggregateMode::Final,
+            PhysicalGroupBy::new_single(vec![(
+                col("a", &schema)? as Arc<dyn PhysicalExpr>,
+                "a".to_string(),
+            )]),
+            vec![Arc::new(
+                AggregateExprBuilder::new(count_udaf(), vec![col("b", &schema)?])
+                    .schema(Arc::clone(&schema))
+                    .alias("COUNT(b)")
+                    .build()?,
+            )],
+            vec![None],
+            input,
+            Arc::clone(&schema),
+        )?;
+
+        let stats = agg.partition_statistics(None)?;
+        assert_eq!(
+            stats.column_statistics[0].distinct_count,
+            Precision::Exact(100),
+            "distinct_count should be propagated from child for group-by columns"
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_aggregate_stats_grouping_sets() -> Result<()> {
+        use crate::test::exec::StatisticsExec;
+        use datafusion_common::ColumnStatistics;
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Int32, true),
+        ]));
+
+        let input_stats = Statistics {
+            num_rows: Precision::Exact(1_000_000),
+            total_byte_size: Precision::Inexact(1_000_000),
+            column_statistics: vec![
+                ColumnStatistics {
+                    distinct_count: Precision::Exact(100),
+                    ..ColumnStatistics::new_unknown()
+                },
+                ColumnStatistics {
+                    distinct_count: Precision::Exact(50),
+                    ..ColumnStatistics::new_unknown()
+                },
+            ],
+        };
+
+        let input = Arc::new(StatisticsExec::new(input_stats, (*schema).clone()))
+            as Arc<dyn ExecutionPlan>;
+
+        // CUBE-like grouping set: (a, NULL), (NULL, b), (a, b) — 3 groups
+        let grouping_set = PhysicalGroupBy::new(
+            vec![
+                (col("a", &schema)? as Arc<dyn PhysicalExpr>, "a".to_string()),
+                (col("b", &schema)? as Arc<dyn PhysicalExpr>, "b".to_string()),
+            ],
+            vec![
+                (lit(ScalarValue::Int32(None)), "a".to_string()),
+                (lit(ScalarValue::Int32(None)), "b".to_string()),
+            ],
+            vec![
+                vec![false, true],  // (a, NULL)
+                vec![true, false],  // (NULL, b)
+                vec![false, false], // (a, b)
+            ],
+            true,
+        );
+
+        let agg = AggregateExec::try_new(
+            AggregateMode::Final,
+            grouping_set,
+            vec![Arc::new(
+                AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?])
+                    .schema(Arc::clone(&schema))
+                    .alias("COUNT(a)")
+                    .build()?,
+            )],
+            vec![None],
+            input,
+            Arc::clone(&schema),
+        )?;
+
+        let stats = agg.partition_statistics(None)?;
+        // NDV product = 100 * 50 = 5000, multiplied by 3 grouping sets = 15000
+        assert_eq!(
+            stats.num_rows,
+            Precision::Inexact(15_000),
+            "grouping sets should multiply NDV product by number of groups"
+        );
+
+        Ok(())
+    }
+
+    #[test]
+    fn test_aggregate_stats_non_column_expr_bails_out() -> Result<()> {
+        use crate::test::exec::StatisticsExec;
+        use datafusion_common::ColumnStatistics;
+        use datafusion_expr::Operator;
+        use datafusion_physical_expr::expressions::BinaryExpr;
+
+        let schema = Arc::new(Schema::new(vec![
+            Field::new("a", DataType::Int32, true),
+            Field::new("b", DataType::Int32, true),
+        ]));
+
+        let input_stats = Statistics {
+            num_rows: Precision::Exact(1_000_000),
+            total_byte_size: Precision::Inexact(1_000_000),
+            column_statistics: vec![
+                ColumnStatistics {
+                    distinct_count: Precision::Exact(100),
+                    ..ColumnStatistics::new_unknown()
+                },
+                ColumnStatistics {
+                    distinct_count: Precision::Exact(50),
+                    ..ColumnStatistics::new_unknown()
+                },
+            ],
+        };
+
+        let input = Arc::new(StatisticsExec::new(input_stats, (*schema).clone()))
+            as Arc<dyn ExecutionPlan>;
+
+        // GROUP BY (a + b) — not a direct column reference
+        let expr_a_plus_b: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
+            col("a", &schema)?,
+            Operator::Plus,
+            col("b", &schema)?,
+        ));
+
+        let agg = AggregateExec::try_new(
+            AggregateMode::Final,
+            PhysicalGroupBy::new_single(vec![(expr_a_plus_b, "a+b".to_string())]),
+            vec![Arc::new(
+                AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?])
+                    .schema(Arc::clone(&schema))
+                    .alias("COUNT(a)")
+                    .build()?,
+            )],
+            vec![None],
+            input,
+            Arc::clone(&schema),
+        )?;
+
+        let stats = agg.partition_statistics(None)?;
+        // Non-column expr bails out of NDV estimation, falls back to input_rows
+        assert_eq!(
+            stats.num_rows,
+            Precision::Inexact(1_000_000),
+            "non-column group-by expression should bail out to input_rows"
+        );
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_order_is_retained_when_spilling() -> Result<()> {
         let schema = Arc::new(Schema::new(vec![