diff --git a/native/core/src/execution/planner.rs b/native/core/src/execution/planner.rs index 5af31fcc22..27240af03c 100644 --- a/native/core/src/execution/planner.rs +++ b/native/core/src/execution/planner.rs @@ -2438,7 +2438,7 @@ impl PhysicalPlanner { sort_phy_exprs, window_frame.into(), input_schema, - false, // TODO: Ignore nulls + spark_expr.ignore_nulls, false, // TODO: Spark does not support DISTINCT ... OVER None, ) @@ -2492,6 +2492,12 @@ impl PhysicalPlanner { .udaf(name) .map(WindowFunctionDefinition::AggregateUDF) .ok() + .or_else(|| { + registry + .udwf(name) + .map(WindowFunctionDefinition::WindowUDF) + .ok() + }) } /// Create a DataFusion physical partitioning from Spark physical partitioning diff --git a/native/proto/src/proto/operator.proto b/native/proto/src/proto/operator.proto index 344b9f0f21..fb438b26a4 100644 --- a/native/proto/src/proto/operator.proto +++ b/native/proto/src/proto/operator.proto @@ -369,6 +369,7 @@ message WindowExpr { spark.spark_expression.Expr built_in_window_function = 1; spark.spark_expression.AggExpr agg_func = 2; WindowSpecDefinition spec = 3; + bool ignore_nulls = 4; } enum WindowFrameType { diff --git a/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala b/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala index 17f5c62469..e642bafa4f 100644 --- a/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala +++ b/spark/src/main/scala/org/apache/spark/sql/comet/CometWindowExec.scala @@ -21,7 +21,7 @@ package org.apache.spark.sql.comet import scala.jdk.CollectionConverters._ -import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CurrentRow, Expression, NamedExpression, RangeFrame, RowFrame, SortOrder, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, WindowExpression} +import org.apache.spark.sql.catalyst.expressions.{Alias, Attribute, AttributeReference, AttributeSet, CurrentRow, Expression, FrameLessOffsetWindowFunction, Lag, Lead, NamedExpression, RangeFrame, RowFrame, SortOrder, SpecifiedWindowFrame, UnboundedFollowing, UnboundedPreceding, WindowExpression} import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, Complete, Count, Max, Min, Sum} import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.SparkPlan @@ -36,7 +36,7 @@ import org.apache.comet.{CometConf, ConfigEntry} import org.apache.comet.CometSparkSessionExtensions.withInfo import org.apache.comet.serde.{AggSerde, CometOperatorSerde, Incompatible, OperatorOuterClass, SupportLevel} import org.apache.comet.serde.OperatorOuterClass.Operator -import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto} +import org.apache.comet.serde.QueryPlanSerde.{aggExprToProto, exprToProto, scalarFunctionExprToProto} object CometWindowExec extends CometOperatorSerde[WindowExec] { @@ -72,7 +72,12 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] { return None } - if (op.partitionSpec.nonEmpty && op.orderSpec.nonEmpty && + // Offset window functions (LAG, LEAD) support arbitrary partition and order specs, so skip + // the validatePartitionAndSortSpecsForWindowFunc check which requires partition columns to + // equal order columns. That stricter check is only needed for aggregate window functions. + val hasOnlyOffsetFunctions = winExprs.nonEmpty && + winExprs.forall(e => e.windowFunction.isInstanceOf[FrameLessOffsetWindowFunction]) + if (!hasOnlyOffsetFunctions && op.partitionSpec.nonEmpty && op.orderSpec.nonEmpty && !validatePartitionAndSortSpecsForWindowFunc(op.partitionSpec, op.orderSpec, op)) { return None } @@ -141,12 +146,27 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] { } }.toArray - val (aggExpr, builtinFunc) = if (aggregateExpressions.nonEmpty) { + val (aggExpr, builtinFunc, ignoreNulls) = if (aggregateExpressions.nonEmpty) { val modes = aggregateExpressions.map(_.mode).distinct assert(modes.size == 1 && modes.head == Complete) - (aggExprToProto(aggregateExpressions.head, output, true, conf), None) + (aggExprToProto(aggregateExpressions.head, output, true, conf), None, false) } else { - (None, exprToProto(windowExpr.windowFunction, output)) + windowExpr.windowFunction match { + case lag: Lag => + val inputExpr = exprToProto(lag.input, output) + val offsetExpr = exprToProto(lag.inputOffset, output) + val defaultExpr = exprToProto(lag.default, output) + val func = scalarFunctionExprToProto("lag", inputExpr, offsetExpr, defaultExpr) + (None, func, lag.ignoreNulls) + case lead: Lead => + val inputExpr = exprToProto(lead.input, output) + val offsetExpr = exprToProto(lead.offset, output) + val defaultExpr = exprToProto(lead.default, output) + val func = scalarFunctionExprToProto("lead", inputExpr, offsetExpr, defaultExpr) + (None, func, lead.ignoreNulls) + case _ => + (None, exprToProto(windowExpr.windowFunction, output), false) + } } if (aggExpr.isEmpty && builtinFunc.isEmpty) { @@ -254,6 +274,7 @@ object CometWindowExec extends CometOperatorSerde[WindowExec] { .newBuilder() .setBuiltInWindowFunction(builtinFunc.get) .setSpec(spec) + .setIgnoreNulls(ignoreNulls) .build()) } else if (aggExpr.isDefined) { Some( diff --git a/spark/src/test/resources/sql-tests/expressions/window/lag_lead.sql b/spark/src/test/resources/sql-tests/expressions/window/lag_lead.sql new file mode 100644 index 0000000000..8f85174301 --- /dev/null +++ b/spark/src/test/resources/sql-tests/expressions/window/lag_lead.sql @@ -0,0 +1,294 @@ +-- Licensed to the Apache Software Foundation (ASF) under one +-- or more contributor license agreements. See the NOTICE file +-- distributed with this work for additional information +-- regarding copyright ownership. The ASF licenses this file +-- to you under the Apache License, Version 2.0 (the +-- "License"); you may not use this file except in compliance +-- with the License. You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. + +-- Config: spark.comet.operator.WindowExec.allowIncompatible=true + +-- ============================================================ +-- Setup: shared tables +-- ============================================================ + +statement +CREATE TABLE test_lag_lead(id int, val int, grp string) USING parquet + +statement +INSERT INTO test_lag_lead VALUES + (1, 10, 'a'), + (2, 20, 'a'), + (3, 30, 'a'), + (4, 40, 'b'), + (5, 50, 'b') + +statement +CREATE TABLE test_nulls(id int, val int, grp string) USING parquet + +statement +INSERT INTO test_nulls VALUES + (1, NULL, 'a'), + (2, 10, 'a'), + (3, NULL, 'a'), + (4, 20, 'a'), + (5, NULL, 'b'), + (6, 30, 'b'), + (7, NULL, 'b') + +statement +CREATE TABLE test_all_nulls(id int, val int, grp string) USING parquet + +statement +INSERT INTO test_all_nulls VALUES + (1, NULL, 'a'), + (2, NULL, 'a'), + (3, NULL, 'b'), + (4, 1, 'b') + +statement +CREATE TABLE test_single_row(id int, val int) USING parquet + +statement +INSERT INTO test_single_row VALUES (1, 42) + +statement +CREATE TABLE test_types( + id int, + i_val int, + l_val bigint, + d_val double, + s_val string, + grp string +) USING parquet + +statement +INSERT INTO test_types VALUES + (1, NULL, NULL, NULL, NULL, 'a'), + (2, 1, 100, 1.5, 'foo', 'a'), + (3, 2, 200, 2.5, 'bar', 'a'), + (4, NULL, NULL, NULL, NULL, 'b'), + (5, 3, 300, 3.5, 'baz', 'b') + +-- ############################################################ +-- LAG +-- ############################################################ + +-- ============================================================ +-- lag: basic (default offset = 1) +-- ============================================================ + +query +SELECT id, val, + LAG(val) OVER (ORDER BY id) as lag_val +FROM test_lag_lead + +query +SELECT grp, id, val, + LAG(val) OVER (PARTITION BY grp ORDER BY id) as lag_val +FROM test_lag_lead + +-- ============================================================ +-- lag: with explicit offset +-- ============================================================ + +query +SELECT id, val, + LAG(val, 2) OVER (ORDER BY id) as lag_val_2 +FROM test_lag_lead + +-- ============================================================ +-- lag: with offset and default value +-- ============================================================ + +query +SELECT id, val, + LAG(val, 2, -1) OVER (ORDER BY id) as lag_val_2 +FROM test_lag_lead + +-- ============================================================ +-- lag IGNORE NULLS: basic +-- ============================================================ + +query +SELECT id, val, + LAG(val) IGNORE NULLS OVER (ORDER BY id) as lag_val +FROM test_nulls + +query +SELECT grp, id, val, + LAG(val) IGNORE NULLS OVER (PARTITION BY grp ORDER BY id) as lag_val +FROM test_nulls + +-- ============================================================ +-- lag IGNORE NULLS: all values null in a group +-- ============================================================ + +query +SELECT grp, id, val, + LAG(val) IGNORE NULLS OVER (PARTITION BY grp ORDER BY id) as lag_val +FROM test_all_nulls + +-- ============================================================ +-- lag IGNORE NULLS: single row +-- ============================================================ + +query +SELECT id, val, + LAG(val) IGNORE NULLS OVER (ORDER BY id) as lag_val +FROM test_single_row + +-- ============================================================ +-- lag IGNORE NULLS: multiple data types +-- ============================================================ + +query +SELECT grp, id, + LAG(i_val) IGNORE NULLS OVER (PARTITION BY grp ORDER BY id), + LAG(l_val) IGNORE NULLS OVER (PARTITION BY grp ORDER BY id), + LAG(d_val) IGNORE NULLS OVER (PARTITION BY grp ORDER BY id), + LAG(s_val) IGNORE NULLS OVER (PARTITION BY grp ORDER BY id) +FROM test_types + +-- ============================================================ +-- lag IGNORE NULLS: with offset > 1 +-- ============================================================ + +query +SELECT id, val, + LAG(val, 2) IGNORE NULLS OVER (ORDER BY id) as lag_val_2 +FROM test_nulls + +-- ============================================================ +-- lag: contrast IGNORE NULLS vs RESPECT NULLS +-- ============================================================ + +query +SELECT id, val, + LAG(val) OVER (ORDER BY id) as lag_respect, + LAG(val) IGNORE NULLS OVER (ORDER BY id) as lag_ignore +FROM test_nulls + +-- ############################################################ +-- LEAD +-- ############################################################ + +-- ============================================================ +-- lead: basic (default offset = 1) +-- ============================================================ + +query +SELECT id, val, + LEAD(val) OVER (ORDER BY id) as lead_val +FROM test_lag_lead + +query +SELECT grp, id, val, + LEAD(val) OVER (PARTITION BY grp ORDER BY id) as lead_val +FROM test_lag_lead + +-- ============================================================ +-- lead: with explicit offset +-- ============================================================ + +query +SELECT id, val, + LEAD(val, 2) OVER (ORDER BY id) as lead_val_2 +FROM test_lag_lead + +-- ============================================================ +-- lead: with offset and default value +-- ============================================================ + +query +SELECT id, val, + LEAD(val, 2, -1) OVER (ORDER BY id) as lead_val_2 +FROM test_lag_lead + +-- ============================================================ +-- lead IGNORE NULLS: basic +-- ============================================================ + +query +SELECT id, val, + LEAD(val) IGNORE NULLS OVER (ORDER BY id) as lead_val +FROM test_nulls + +query +SELECT grp, id, val, + LEAD(val) IGNORE NULLS OVER (PARTITION BY grp ORDER BY id) as lead_val +FROM test_nulls + +-- ============================================================ +-- lead IGNORE NULLS: all values null in a group +-- ============================================================ + +query +SELECT grp, id, val, + LEAD(val) IGNORE NULLS OVER (PARTITION BY grp ORDER BY id) as lead_val +FROM test_all_nulls + +-- ============================================================ +-- lead IGNORE NULLS: single row +-- ============================================================ + +query +SELECT id, val, + LEAD(val) IGNORE NULLS OVER (ORDER BY id) as lead_val +FROM test_single_row + +-- ============================================================ +-- lead IGNORE NULLS: multiple data types +-- ============================================================ + +query +SELECT grp, id, + LEAD(i_val) IGNORE NULLS OVER (PARTITION BY grp ORDER BY id), + LEAD(l_val) IGNORE NULLS OVER (PARTITION BY grp ORDER BY id), + LEAD(d_val) IGNORE NULLS OVER (PARTITION BY grp ORDER BY id), + LEAD(s_val) IGNORE NULLS OVER (PARTITION BY grp ORDER BY id) +FROM test_types + +-- ============================================================ +-- lead IGNORE NULLS: with offset > 1 +-- ============================================================ + +query +SELECT id, val, + LEAD(val, 2) IGNORE NULLS OVER (ORDER BY id) as lead_val_2 +FROM test_nulls + +-- ============================================================ +-- lead: contrast IGNORE NULLS vs RESPECT NULLS +-- ============================================================ + +query +SELECT id, val, + LEAD(val) OVER (ORDER BY id) as lead_respect, + LEAD(val) IGNORE NULLS OVER (ORDER BY id) as lead_ignore +FROM test_nulls + +-- ############################################################ +-- LAG + LEAD combined +-- ############################################################ + +query +SELECT id, val, + LAG(val) OVER (ORDER BY id) as lag_val, + LEAD(val) OVER (ORDER BY id) as lead_val +FROM test_lag_lead + +query +SELECT grp, id, val, + LAG(val) IGNORE NULLS OVER (PARTITION BY grp ORDER BY id) as lag_ignore, + LEAD(val) IGNORE NULLS OVER (PARTITION BY grp ORDER BY id) as lead_ignore +FROM test_nulls diff --git a/spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala b/spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala index a1a24f4c2b..23acc2b16d 100644 --- a/spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala +++ b/spark/src/test/scala/org/apache/comet/exec/CometWindowExecSuite.scala @@ -26,6 +26,7 @@ import org.apache.hadoop.fs.Path import org.apache.spark.sql.{CometTestBase, Row} import org.apache.spark.sql.comet.CometWindowExec import org.apache.spark.sql.comet.execution.shuffle.CometShuffleExchangeExec +import org.apache.spark.sql.execution.window.WindowExec import org.apache.spark.sql.expressions.Window import org.apache.spark.sql.functions.{count, lead, sum} import org.apache.spark.sql.internal.SQLConf @@ -605,87 +606,131 @@ class CometWindowExecSuite extends CometTestBase { } } - // TODO: LAG produces incorrect results - ignore("window: LAG with default offset") { - withTempDir { dir => - (0 until 30) - .map(i => (i % 3, i % 5, i)) - .toDF("a", "b", "c") - .repartition(3) - .write - .mode("overwrite") - .parquet(dir.toString) - - spark.read.parquet(dir.toString).createOrReplaceTempView("window_test") - val df = sql(""" - SELECT a, b, c, - LAG(c) OVER (PARTITION BY a ORDER BY b) as lag_c - FROM window_test - """) - checkSparkAnswerAndOperator(df) + test("window: LAG with default offset") { + withSQLConf(CometConf.getOperatorAllowIncompatConfigKey(classOf[WindowExec]) -> "true") { + withTempDir { dir => + (0 until 30) + .map(i => (i % 3, i % 5, i)) + .toDF("a", "b", "c") + .repartition(3) + .write + .mode("overwrite") + .parquet(dir.toString) + + spark.read.parquet(dir.toString).createOrReplaceTempView("window_test") + val df = sql(""" + SELECT a, b, c, + LAG(c) OVER (PARTITION BY a ORDER BY b, c) as lag_c + FROM window_test + """) + checkSparkAnswerAndOperator(df) + } } } - // TODO: LAG with offset 2 produces incorrect results - ignore("window: LAG with offset 2 and default value") { - withTempDir { dir => - (0 until 30) - .map(i => (i % 3, i % 5, i)) - .toDF("a", "b", "c") - .repartition(3) - .write - .mode("overwrite") - .parquet(dir.toString) - - spark.read.parquet(dir.toString).createOrReplaceTempView("window_test") - val df = sql(""" - SELECT a, b, c, - LAG(c, 2, -1) OVER (PARTITION BY a ORDER BY b) as lag_c_2 - FROM window_test - """) - checkSparkAnswerAndOperator(df) + test("window: LAG with offset 2 and default value") { + withSQLConf(CometConf.getOperatorAllowIncompatConfigKey(classOf[WindowExec]) -> "true") { + withTempDir { dir => + (0 until 30) + .map(i => (i % 3, i % 5, i)) + .toDF("a", "b", "c") + .repartition(3) + .write + .mode("overwrite") + .parquet(dir.toString) + + spark.read.parquet(dir.toString).createOrReplaceTempView("window_test") + val df = sql(""" + SELECT a, b, c, + LAG(c, 2, -1) OVER (PARTITION BY a ORDER BY b, c) as lag_c_2 + FROM window_test + """) + checkSparkAnswerAndOperator(df) + } } } - // TODO: LEAD produces incorrect results - ignore("window: LEAD with default offset") { - withTempDir { dir => - (0 until 30) - .map(i => (i % 3, i % 5, i)) - .toDF("a", "b", "c") - .repartition(3) - .write - .mode("overwrite") - .parquet(dir.toString) + test("window: LAG with IGNORE NULLS") { + withSQLConf(CometConf.getOperatorAllowIncompatConfigKey(classOf[WindowExec]) -> "true") { + withTempDir { dir => + Seq((1, 1, Some(10)), (1, 2, None), (1, 3, Some(30)), (2, 1, None), (2, 2, Some(20))) + .toDF("a", "b", "c") + .write + .mode("overwrite") + .parquet(dir.toString) + + spark.read.parquet(dir.toString).createOrReplaceTempView("window_test") + val df = sql(""" + SELECT a, b, c, + LAG(c) IGNORE NULLS OVER (PARTITION BY a ORDER BY b) as lag_c + FROM window_test + """) + checkSparkAnswerAndOperator(df) + } + } + } - spark.read.parquet(dir.toString).createOrReplaceTempView("window_test") - val df = sql(""" - SELECT a, b, c, - LEAD(c) OVER (PARTITION BY a ORDER BY b) as lead_c - FROM window_test - """) - checkSparkAnswerAndOperator(df) + test("window: LEAD with default offset") { + withSQLConf(CometConf.getOperatorAllowIncompatConfigKey(classOf[WindowExec]) -> "true") { + withTempDir { dir => + (0 until 30) + .map(i => (i % 3, i % 5, i)) + .toDF("a", "b", "c") + .repartition(3) + .write + .mode("overwrite") + .parquet(dir.toString) + + spark.read.parquet(dir.toString).createOrReplaceTempView("window_test") + val df = sql(""" + SELECT a, b, c, + LEAD(c) OVER (PARTITION BY a ORDER BY b, c) as lead_c + FROM window_test + """) + checkSparkAnswerAndOperator(df) + } } } - // TODO: LEAD with offset 2 produces incorrect results - ignore("window: LEAD with offset 2 and default value") { - withTempDir { dir => - (0 until 30) - .map(i => (i % 3, i % 5, i)) - .toDF("a", "b", "c") - .repartition(3) - .write - .mode("overwrite") - .parquet(dir.toString) + test("window: LEAD with offset 2 and default value") { + withSQLConf(CometConf.getOperatorAllowIncompatConfigKey(classOf[WindowExec]) -> "true") { + withTempDir { dir => + (0 until 30) + .map(i => (i % 3, i % 5, i)) + .toDF("a", "b", "c") + .repartition(3) + .write + .mode("overwrite") + .parquet(dir.toString) + + spark.read.parquet(dir.toString).createOrReplaceTempView("window_test") + val df = sql(""" + SELECT a, b, c, + LEAD(c, 2, -1) OVER (PARTITION BY a ORDER BY b, c) as lead_c_2 + FROM window_test + """) + checkSparkAnswerAndOperator(df) + } + } + } - spark.read.parquet(dir.toString).createOrReplaceTempView("window_test") - val df = sql(""" - SELECT a, b, c, - LEAD(c, 2, -1) OVER (PARTITION BY a ORDER BY b) as lead_c_2 - FROM window_test - """) - checkSparkAnswerAndOperator(df) + test("window: LEAD with IGNORE NULLS") { + withSQLConf(CometConf.getOperatorAllowIncompatConfigKey(classOf[WindowExec]) -> "true") { + withTempDir { dir => + Seq((1, 1, Some(10)), (1, 2, None), (1, 3, Some(30)), (2, 1, None), (2, 2, Some(20))) + .toDF("a", "b", "c") + .write + .mode("overwrite") + .parquet(dir.toString) + + spark.read.parquet(dir.toString).createOrReplaceTempView("window_test") + val df = sql(""" + SELECT a, b, c, + LEAD(c) IGNORE NULLS OVER (PARTITION BY a ORDER BY b) as lead_c + FROM window_test + """) + checkSparkAnswerAndOperator(df) + } } }