From 0723a778f07462531fe7c6150f9bfd81d328f76f Mon Sep 17 00:00:00 2001 From: buraksenn Date: Fri, 13 Mar 2026 18:30:41 +0300 Subject: [PATCH 1/3] feat: implemented topk cardinality aggregate estimation --- .../physical-plan/src/aggregates/mod.rs | 558 +++++++++++++++++- 1 file changed, 548 insertions(+), 10 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index e573a7f85ddf9..6fa78a1981638 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1037,6 +1037,23 @@ impl AggregateExec { &self.input_order_mode } + /// Estimates output statistics for this aggregate node. + /// + /// For grouped aggregations with known input row count > 1, the output row + /// count is estimated as: + /// + /// ```text + /// output_rows = input_rows // baseline + /// output_rows = min(output_rows, product(NDV_i + nulls_i) * sets) // if NDV available + /// output_rows = min(output_rows, limit) // if TopK active + /// ``` + /// + /// The NDV estimation is heavily inspired by Spark and Trino. + /// - Multiplies distinct value counts of all group-by columns + /// - Adds +1 per column when nulls are present (a null group is a distinct output row) + /// - Caps the result by the input row count + /// - Requires NDV stats for ALL group-by columns; if any column lacks stats, + /// falls back to `input_rows` as the upper bound fn statistics_inner(&self, child_statistics: &Statistics) -> Result { // TODO stats: group expressions: // - once expressions will be able to compute their own stats, use it here @@ -1050,15 +1067,12 @@ impl AggregateExec { for (idx, (expr, _)) in self.group_by.expr.iter().enumerate() { if let Some(col) = expr.as_any().downcast_ref::() { - column_statistics[idx].max_value = child_statistics.column_statistics - [col.index()] - .max_value - .clone(); - - column_statistics[idx].min_value = child_statistics.column_statistics - [col.index()] - .min_value - .clone(); + let child_col_stats = + &child_statistics.column_statistics[col.index()]; + column_statistics[idx].max_value = child_col_stats.max_value.clone(); + column_statistics[idx].min_value = child_col_stats.min_value.clone(); + column_statistics[idx].distinct_count = + child_col_stats.distinct_count; } } @@ -1083,7 +1097,23 @@ impl AggregateExec { let num_rows = if let Some(value) = child_statistics.num_rows.get_value() { if *value > 1 { - child_statistics.num_rows.to_inexact() + let mut num_rows = child_statistics.num_rows.to_inexact(); + + if !self.group_by.expr.is_empty() { + let ndv_product = self.compute_group_ndv(child_statistics); + if let Some(ndv) = ndv_product { + let grouping_set_num = self.group_by.groups.len(); + let ndv_estimate = ndv.saturating_mul(grouping_set_num); + num_rows = num_rows.map(|n| n.min(ndv_estimate)); + } + } + + // If TopK mode is active, cap output rows by the limit + if let Some(limit_opts) = &self.limit_options { + num_rows = num_rows.map(|n| n.min(limit_opts.limit)); + } + + num_rows } else if *value == 0 { child_statistics.num_rows } else { @@ -1091,6 +1121,8 @@ impl AggregateExec { let grouping_set_num = self.group_by.groups.len(); child_statistics.num_rows.map(|x| x * grouping_set_num) } + } else if let Some(limit_opts) = &self.limit_options { + Precision::Inexact(limit_opts.limit) } else { Precision::Absent }; @@ -1113,6 +1145,24 @@ impl AggregateExec { } } + /// Computes `product(NDV_i + null_adjustment_i)` across all group-by columns. + /// Returns `None` if any group-by column is not a direct column reference + /// or lacks `distinct_count` stats. + fn compute_group_ndv(&self, child_statistics: &Statistics) -> Option { + let mut product: usize = 1; + for (expr, _) in self.group_by.expr.iter() { + let col = expr.as_any().downcast_ref::()?; + let col_stats = &child_statistics.column_statistics[col.index()]; + let ndv = *col_stats.distinct_count.get_value()?; + let null_adjustment = match col_stats.null_count.get_value() { + Some(&n) if n > 0 => 1usize, + _ => 0, + }; + product = product.saturating_mul(ndv.saturating_add(null_adjustment)); + } + Some(product) + } + /// Check if dynamic filter is possible for the current plan node. /// - If yes, init one inside `AggregateExec`'s `dynamic_filter` field. /// - If not supported, `self.dynamic_filter` should be kept `None` @@ -3795,6 +3845,494 @@ mod tests { Ok(()) } + #[test] + fn test_aggregate_cardinality_estimation() -> Result<()> { + use crate::test::exec::StatisticsExec; + use datafusion_common::ColumnStatistics; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + struct TestCase { + name: &'static str, + input_rows: Precision, + col_a_stats: ColumnStatistics, + col_b_stats: ColumnStatistics, + group_by_cols: Vec<&'static str>, + limit_options: Option, + expected_num_rows: Precision, + } + + let cases = vec![ + // --- NDV-based estimation --- + TestCase { + name: "single group-by col with NDV tightens estimate", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(500), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: None, + expected_num_rows: Precision::Inexact(500), + }, + TestCase { + name: "multi-col group-by multiplies NDVs", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(100), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics { + distinct_count: Precision::Exact(50), + ..ColumnStatistics::new_unknown() + }, + group_by_cols: vec!["a", "b"], + limit_options: None, + expected_num_rows: Precision::Inexact(5_000), + }, + TestCase { + name: "NDV product capped by input rows", + input_rows: Precision::Exact(200), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(100), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics { + distinct_count: Precision::Exact(50), + ..ColumnStatistics::new_unknown() + }, + group_by_cols: vec!["a", "b"], + limit_options: None, + expected_num_rows: Precision::Inexact(200), + }, + TestCase { + name: "null adjustment adds +1 per column", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(99), + null_count: Precision::Exact(10), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: None, + // 99 + 1 (null adjustment) = 100 + expected_num_rows: Precision::Inexact(100), + }, + TestCase { + name: "null adjustment on multiple columns", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(99), + null_count: Precision::Exact(5), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics { + distinct_count: Precision::Exact(49), + null_count: Precision::Exact(3), + ..ColumnStatistics::new_unknown() + }, + group_by_cols: vec!["a", "b"], + limit_options: None, + // (99+1) * (49+1) = 100 * 50 = 5000 + expected_num_rows: Precision::Inexact(5_000), + }, + TestCase { + name: "zero null_count means no adjustment", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(100), + null_count: Precision::Exact(0), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: None, + expected_num_rows: Precision::Inexact(100), + }, + // --- Bail-out: partial NDV stats (Spark-style) --- + TestCase { + name: "bail out when one group-by col lacks NDV", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(100), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a", "b"], + limit_options: None, + expected_num_rows: Precision::Inexact(1_000_000), + }, + TestCase { + name: "bail out when all group-by cols lack NDV", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics::new_unknown(), + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: None, + expected_num_rows: Precision::Inexact(1_000_000), + }, + // --- TopK limit capping --- + TestCase { + name: "TopK limit caps output rows", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics::new_unknown(), + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: Some(LimitOptions::new(10)), + expected_num_rows: Precision::Inexact(10), + }, + TestCase { + name: "NDV + TopK limit: min(NDV, limit) when NDV < limit", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(5), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: Some(LimitOptions::new(10)), + expected_num_rows: Precision::Inexact(5), + }, + TestCase { + name: "NDV + TopK limit: min(NDV, limit) when limit < NDV", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(500), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: Some(LimitOptions::new(10)), + expected_num_rows: Precision::Inexact(10), + }, + // --- Absent input rows --- + TestCase { + name: "absent input rows without limit stays absent", + input_rows: Precision::Absent, + col_a_stats: ColumnStatistics::new_unknown(), + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: None, + expected_num_rows: Precision::Absent, + }, + TestCase { + name: "absent input rows with TopK limit gives inexact(limit)", + input_rows: Precision::Absent, + col_a_stats: ColumnStatistics::new_unknown(), + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: Some(LimitOptions::new(10)), + expected_num_rows: Precision::Inexact(10), + }, + // --- No group-by (global aggregation) --- + TestCase { + name: "no group-by cols (Final mode) returns Exact(1)", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics::new_unknown(), + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec![], + limit_options: None, + expected_num_rows: Precision::Exact(1), + }, + // --- One input row --- + TestCase { + name: "one input row returns Exact(1)", + input_rows: Precision::Exact(1), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Exact(1), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: None, + expected_num_rows: Precision::Exact(1), + }, + // --- Zero input rows --- + TestCase { + name: "zero input rows returns Exact(0)", + input_rows: Precision::Exact(0), + col_a_stats: ColumnStatistics::new_unknown(), + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: None, + expected_num_rows: Precision::Exact(0), + }, + // --- Inexact NDV stats --- + TestCase { + name: "inexact NDV still used for estimation", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Inexact(200), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: None, + expected_num_rows: Precision::Inexact(200), + }, + TestCase { + name: "inexact NDV combined with limit", + input_rows: Precision::Exact(1_000_000), + col_a_stats: ColumnStatistics { + distinct_count: Precision::Inexact(200), + ..ColumnStatistics::new_unknown() + }, + col_b_stats: ColumnStatistics::new_unknown(), + group_by_cols: vec!["a"], + limit_options: Some(LimitOptions::new(10)), + expected_num_rows: Precision::Inexact(10), + }, + ]; + + for case in cases { + let input_stats = Statistics { + num_rows: case.input_rows, + total_byte_size: Precision::Inexact(1_000_000), + column_statistics: vec![ + case.col_a_stats.clone(), + case.col_b_stats.clone(), + ], + }; + + let input = Arc::new(StatisticsExec::new(input_stats, (*schema).clone())) + as Arc; + + let group_by = if case.group_by_cols.is_empty() { + PhysicalGroupBy::default() + } else { + PhysicalGroupBy::new_single( + case.group_by_cols + .iter() + .map(|name| { + ( + col(name, &schema).unwrap() as Arc, + name.to_string(), + ) + }) + .collect(), + ) + }; + + let mut agg = AggregateExec::try_new( + AggregateMode::Final, + group_by, + vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("COUNT(a)") + .build()?, + )], + vec![None], + input, + Arc::clone(&schema), + )?; + + if let Some(limit) = case.limit_options { + agg = agg.with_limit_options(Some(limit)); + } + + let stats = agg.partition_statistics(None)?; + assert_eq!( + stats.num_rows, case.expected_num_rows, + "FAILED: '{}' — expected {:?}, got {:?}", + case.name, case.expected_num_rows, stats.num_rows + ); + } + + Ok(()) + } + + #[test] + fn test_aggregate_stats_distinct_count_propagation() -> Result<()> { + use crate::test::exec::StatisticsExec; + use datafusion_common::ColumnStatistics; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let input = Arc::new(StatisticsExec::new( + Statistics { + num_rows: Precision::Exact(1000), + total_byte_size: Precision::Inexact(10000), + column_statistics: vec![ + ColumnStatistics { + distinct_count: Precision::Exact(100), + null_count: Precision::Exact(5), + ..ColumnStatistics::new_unknown() + }, + ColumnStatistics::new_unknown(), + ], + }, + (*schema).clone(), + )) as Arc; + + let agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new_single(vec![( + col("a", &schema)? as Arc, + "a".to_string(), + )]), + vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("b", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("COUNT(b)") + .build()?, + )], + vec![None], + input, + Arc::clone(&schema), + )?; + + let stats = agg.partition_statistics(None)?; + assert_eq!( + stats.column_statistics[0].distinct_count, + Precision::Exact(100), + "distinct_count should be propagated from child for group-by columns" + ); + + Ok(()) + } + + #[test] + fn test_aggregate_stats_grouping_sets() -> Result<()> { + use crate::test::exec::StatisticsExec; + use datafusion_common::ColumnStatistics; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let input_stats = Statistics { + num_rows: Precision::Exact(1_000_000), + total_byte_size: Precision::Inexact(1_000_000), + column_statistics: vec![ + ColumnStatistics { + distinct_count: Precision::Exact(100), + ..ColumnStatistics::new_unknown() + }, + ColumnStatistics { + distinct_count: Precision::Exact(50), + ..ColumnStatistics::new_unknown() + }, + ], + }; + + let input = Arc::new(StatisticsExec::new(input_stats, (*schema).clone())) + as Arc; + + // CUBE-like grouping set: (a, NULL), (NULL, b), (a, b) — 3 groups + let grouping_set = PhysicalGroupBy::new( + vec![ + (col("a", &schema)? as Arc, "a".to_string()), + (col("b", &schema)? as Arc, "b".to_string()), + ], + vec![ + (lit(ScalarValue::Int32(None)), "a".to_string()), + (lit(ScalarValue::Int32(None)), "b".to_string()), + ], + vec![ + vec![false, true], // (a, NULL) + vec![true, false], // (NULL, b) + vec![false, false], // (a, b) + ], + true, + ); + + let agg = AggregateExec::try_new( + AggregateMode::Final, + grouping_set, + vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("COUNT(a)") + .build()?, + )], + vec![None], + input, + Arc::clone(&schema), + )?; + + let stats = agg.partition_statistics(None)?; + // NDV product = 100 * 50 = 5000, multiplied by 3 grouping sets = 15000 + assert_eq!( + stats.num_rows, + Precision::Inexact(15_000), + "grouping sets should multiply NDV product by number of groups" + ); + + Ok(()) + } + + #[test] + fn test_aggregate_stats_non_column_expr_bails_out() -> Result<()> { + use crate::test::exec::StatisticsExec; + use datafusion_common::ColumnStatistics; + use datafusion_expr::Operator; + use datafusion_physical_expr::expressions::BinaryExpr; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let input_stats = Statistics { + num_rows: Precision::Exact(1_000_000), + total_byte_size: Precision::Inexact(1_000_000), + column_statistics: vec![ + ColumnStatistics { + distinct_count: Precision::Exact(100), + ..ColumnStatistics::new_unknown() + }, + ColumnStatistics { + distinct_count: Precision::Exact(50), + ..ColumnStatistics::new_unknown() + }, + ], + }; + + let input = Arc::new(StatisticsExec::new(input_stats, (*schema).clone())) + as Arc; + + // GROUP BY (a + b) — not a direct column reference + let expr_a_plus_b: Arc = Arc::new(BinaryExpr::new( + col("a", &schema)?, + Operator::Plus, + col("b", &schema)?, + )); + + let agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new_single(vec![(expr_a_plus_b, "a+b".to_string())]), + vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("COUNT(a)") + .build()?, + )], + vec![None], + input, + Arc::clone(&schema), + )?; + + let stats = agg.partition_statistics(None)?; + // Non-column expr bails out of NDV estimation, falls back to input_rows + assert_eq!( + stats.num_rows, + Precision::Inexact(1_000_000), + "non-column group-by expression should bail out to input_rows" + ); + + Ok(()) + } + #[tokio::test] async fn test_order_is_retained_when_spilling() -> Result<()> { let schema = Arc::new(Schema::new(vec![ From e12bda647227ce69f4933ccda62406132b89c4e6 Mon Sep 17 00:00:00 2001 From: buraksenn Date: Fri, 13 Mar 2026 18:51:09 +0300 Subject: [PATCH 2/3] fix physical optimizer test --- .../core/tests/physical_optimizer/partition_statistics.rs | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/datafusion/core/tests/physical_optimizer/partition_statistics.rs b/datafusion/core/tests/physical_optimizer/partition_statistics.rs index 12ce141b4759a..6abe9a5f6b0cd 100644 --- a/datafusion/core/tests/physical_optimizer/partition_statistics.rs +++ b/datafusion/core/tests/physical_optimizer/partition_statistics.rs @@ -933,7 +933,10 @@ mod test { num_rows: Precision::Exact(0), total_byte_size: Precision::Absent, column_statistics: vec![ - ColumnStatistics::new_unknown(), + ColumnStatistics { + distinct_count: Precision::Exact(0), + ..ColumnStatistics::new_unknown() + }, ColumnStatistics::new_unknown(), ColumnStatistics::new_unknown(), ], From 0d19879bc641aa6293b7a43c22e32b7113c33667 Mon Sep 17 00:00:00 2001 From: buraksenn Date: Tue, 17 Mar 2026 17:28:55 +0300 Subject: [PATCH 3/3] address reviews and adjust comments --- .../physical-plan/src/aggregates/mod.rs | 252 +++++++++++++++--- 1 file changed, 216 insertions(+), 36 deletions(-) diff --git a/datafusion/physical-plan/src/aggregates/mod.rs b/datafusion/physical-plan/src/aggregates/mod.rs index 6fa78a1981638..dd002b90c4b12 100644 --- a/datafusion/physical-plan/src/aggregates/mod.rs +++ b/datafusion/physical-plan/src/aggregates/mod.rs @@ -1043,17 +1043,28 @@ impl AggregateExec { /// count is estimated as: /// /// ```text - /// output_rows = input_rows // baseline - /// output_rows = min(output_rows, product(NDV_i + nulls_i) * sets) // if NDV available - /// output_rows = min(output_rows, limit) // if TopK active + /// ndv = sum over each grouping set of product(max(NDV_i + nulls_i, 1)) + /// output_rows = input_rows // baseline + /// output_rows = min(output_rows, ndv) // if NDV available + /// output_rows = min(output_rows, limit) // if TopK active /// ``` /// - /// The NDV estimation is heavily inspired by Spark and Trino. - /// - Multiplies distinct value counts of all group-by columns - /// - Adds +1 per column when nulls are present (a null group is a distinct output row) - /// - Caps the result by the input row count - /// - Requires NDV stats for ALL group-by columns; if any column lacks stats, - /// falls back to `input_rows` as the upper bound + /// When `input_rows` is absent but NDV is available, falls back to: + /// + /// ```text + /// output_rows = min(ndv, limit) // if both available + /// output_rows = ndv // if only NDV available + /// output_rows = limit // if only limit available + /// ``` + /// + /// NDV estimation details (see [`Self::compute_group_ndv`]): + /// - For each grouping set, only active (non-NULL) columns contribute + /// - Per-column contribution is `max(NDV + null_adj, 1)` where `null_adj` + /// is 1 when nulls are present, 0 otherwise (a null group is a distinct + /// output row; `.max(1)` prevents a zero NDV from zeroing the product) + /// - Per-set products are summed across all grouping sets + /// - Requires NDV stats for ALL active group-by columns; if any lacks stats, + /// falls back to `input_rows` (or `Absent` if that is also unknown) fn statistics_inner(&self, child_statistics: &Statistics) -> Result { // TODO stats: group expressions: // - once expressions will be able to compute their own stats, use it here @@ -1099,13 +1110,10 @@ impl AggregateExec { if *value > 1 { let mut num_rows = child_statistics.num_rows.to_inexact(); - if !self.group_by.expr.is_empty() { - let ndv_product = self.compute_group_ndv(child_statistics); - if let Some(ndv) = ndv_product { - let grouping_set_num = self.group_by.groups.len(); - let ndv_estimate = ndv.saturating_mul(grouping_set_num); - num_rows = num_rows.map(|n| n.min(ndv_estimate)); - } + if !self.group_by.expr.is_empty() + && let Some(ndv) = self.compute_group_ndv(child_statistics) + { + num_rows = num_rows.map(|n| n.min(ndv)); } // If TopK mode is active, cap output rows by the limit @@ -1121,10 +1129,20 @@ impl AggregateExec { let grouping_set_num = self.group_by.groups.len(); child_statistics.num_rows.map(|x| x * grouping_set_num) } - } else if let Some(limit_opts) = &self.limit_options { - Precision::Inexact(limit_opts.limit) } else { - Precision::Absent + let ndv = if !self.group_by.expr.is_empty() { + self.compute_group_ndv(child_statistics) + } else { + None + }; + match (ndv, &self.limit_options) { + (Some(n), Some(limit_opts)) => { + Precision::Inexact(n.min(limit_opts.limit)) + } + (Some(n), None) => Precision::Inexact(n), + (None, Some(limit_opts)) => Precision::Inexact(limit_opts.limit), + (None, None) => Precision::Absent, + } }; let total_byte_size = num_rows @@ -1145,22 +1163,32 @@ impl AggregateExec { } } - /// Computes `product(NDV_i + null_adjustment_i)` across all group-by columns. - /// Returns `None` if any group-by column is not a direct column reference - /// or lacks `distinct_count` stats. + /// Computes the estimated number of distinct groups across all grouping sets. + /// For each grouping set, computes `product(NDV_i + null_adj_i)` for active columns, + /// then sums across all sets. Returns `None` if any active column is not a direct + /// column reference or lacks `distinct_count` stats. + /// When `null_count` is absent or unknown, null_adjustment defaults to 0. fn compute_group_ndv(&self, child_statistics: &Statistics) -> Option { - let mut product: usize = 1; - for (expr, _) in self.group_by.expr.iter() { - let col = expr.as_any().downcast_ref::()?; - let col_stats = &child_statistics.column_statistics[col.index()]; - let ndv = *col_stats.distinct_count.get_value()?; - let null_adjustment = match col_stats.null_count.get_value() { - Some(&n) if n > 0 => 1usize, - _ => 0, - }; - product = product.saturating_mul(ndv.saturating_add(null_adjustment)); + let mut total: usize = 0; + for group_mask in &self.group_by.groups { + let mut set_product: usize = 1; + for (j, (expr, _)) in self.group_by.expr.iter().enumerate() { + if group_mask[j] { + continue; + } + let col = expr.as_any().downcast_ref::()?; + let col_stats = &child_statistics.column_statistics[col.index()]; + let ndv = *col_stats.distinct_count.get_value()?; + let null_adjustment = match col_stats.null_count.get_value() { + Some(&n) if n > 0 => 1usize, + _ => 0, + }; + set_product = set_product + .saturating_mul(ndv.saturating_add(null_adjustment).max(1)); + } + total = total.saturating_add(set_product); } - Some(product) + Some(total) } /// Check if dynamic filter is possible for the current plan node. @@ -4261,11 +4289,12 @@ mod tests { )?; let stats = agg.partition_statistics(None)?; - // NDV product = 100 * 50 = 5000, multiplied by 3 grouping sets = 15000 + // Per-set NDV: (a,NULL)=100, (NULL,b)=50, (a,b)=100*50=5000 + // Total = 100 + 50 + 5000 = 5150 assert_eq!( stats.num_rows, - Precision::Inexact(15_000), - "grouping sets should multiply NDV product by number of groups" + Precision::Inexact(5_150), + "grouping sets should sum per-set NDV products" ); Ok(()) @@ -4333,6 +4362,157 @@ mod tests { Ok(()) } + #[test] + fn test_aggregate_stats_ndv_zero_column() -> Result<()> { + use crate::test::exec::StatisticsExec; + use datafusion_common::ColumnStatistics; + + let schema = Arc::new(Schema::new(vec![ + Field::new("a", DataType::Int32, true), + Field::new("b", DataType::Int32, true), + ])); + + let input_stats = Statistics { + num_rows: Precision::Exact(1_000), + total_byte_size: Precision::Inexact(1_000), + column_statistics: vec![ + ColumnStatistics { + distinct_count: Precision::Exact(0), + null_count: Precision::Exact(1_000), + ..ColumnStatistics::new_unknown() + }, + ColumnStatistics { + distinct_count: Precision::Exact(50), + ..ColumnStatistics::new_unknown() + }, + ], + }; + + let input = Arc::new(StatisticsExec::new(input_stats, (*schema).clone())) + as Arc; + + let agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new_single(vec![ + (col("a", &schema)? as Arc, "a".to_string()), + (col("b", &schema)? as Arc, "b".to_string()), + ]), + vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("COUNT(a)") + .build()?, + )], + vec![None], + input, + Arc::clone(&schema), + )?; + + let stats = agg.partition_statistics(None)?; + // NDV(a)=0 with nulls => max(0+1, 1)=1, NDV(b)=50 => 1*50=50 + assert_eq!( + stats.num_rows, + Precision::Inexact(50), + "all-null column should contribute 1 to the product, not 0" + ); + + Ok(()) + } + + #[test] + fn test_aggregate_stats_absent_num_rows_with_ndv() -> Result<()> { + use crate::test::exec::StatisticsExec; + use datafusion_common::ColumnStatistics; + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + + let input_stats = Statistics { + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: vec![ColumnStatistics { + distinct_count: Precision::Exact(100), + ..ColumnStatistics::new_unknown() + }], + }; + + let input = Arc::new(StatisticsExec::new(input_stats, (*schema).clone())) + as Arc; + + let agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new_single(vec![( + col("a", &schema)? as Arc, + "a".to_string(), + )]), + vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("COUNT(a)") + .build()?, + )], + vec![None], + input, + Arc::clone(&schema), + )?; + + let stats = agg.partition_statistics(None)?; + assert_eq!( + stats.num_rows, + Precision::Inexact(100), + "absent num_rows should fall back to NDV estimate" + ); + + Ok(()) + } + + #[test] + fn test_aggregate_stats_absent_num_rows_with_ndv_and_limit() -> Result<()> { + use crate::test::exec::StatisticsExec; + use datafusion_common::ColumnStatistics; + + let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, true)])); + + let input_stats = Statistics { + num_rows: Precision::Absent, + total_byte_size: Precision::Absent, + column_statistics: vec![ColumnStatistics { + distinct_count: Precision::Exact(100), + ..ColumnStatistics::new_unknown() + }], + }; + + let input = Arc::new(StatisticsExec::new(input_stats, (*schema).clone())) + as Arc; + + let mut agg = AggregateExec::try_new( + AggregateMode::Final, + PhysicalGroupBy::new_single(vec![( + col("a", &schema)? as Arc, + "a".to_string(), + )]), + vec![Arc::new( + AggregateExprBuilder::new(count_udaf(), vec![col("a", &schema)?]) + .schema(Arc::clone(&schema)) + .alias("COUNT(a)") + .build()?, + )], + vec![None], + input, + Arc::clone(&schema), + )?; + + agg = agg.with_limit_options(Some(LimitOptions::new(10))); + + let stats = agg.partition_statistics(None)?; + assert_eq!( + stats.num_rows, + Precision::Inexact(10), + "absent num_rows with NDV and limit should return min(ndv, limit)" + ); + + Ok(()) + } + #[tokio::test] async fn test_order_is_retained_when_spilling() -> Result<()> { let schema = Arc::new(Schema::new(vec![