[AURON #2067] Implement native function of instr #2085
xuzifu666 wants to merge 6 commits into apache:master
Conversation
ShreyeshArangath left a comment:
Let's add a few tests in AuronFunctionSuite.scala so we can verify the behavior with respect to Spark.
```rust
if substr.is_empty() {
    Some(0)
} else {
    Some(s.find(substr).map(|pos| (pos + 1) as i32).unwrap_or(0))
```
You might want to check this, but I think Rust's str::find() returns a byte offset, not a character offset. Spark's instr returns a 1-based character position. This produces wrong results for any multi-byte UTF-8 input.
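To illustrate the byte-vs-character distinction, here is a minimal sketch (not the PR's code; `instr_char_pos` is an illustrative helper name) that converts the byte offset from `str::find()` into the 1-based character position Spark expects:

```rust
// Hedged sketch: compute a 1-based *character* position, as Spark's instr does,
// instead of returning the byte offset from str::find().
fn instr_char_pos(s: &str, substr: &str) -> i32 {
    if substr.is_empty() {
        return 0; // mirrors the PR's empty-substring branch
    }
    match s.find(substr) {
        // s.find returns a byte offset; counting the chars in the prefix
        // before the match yields the character position.
        Some(byte_pos) => s[..byte_pos].chars().count() as i32 + 1,
        None => 0,
    }
}

fn main() {
    // ASCII input: byte and character positions coincide.
    assert_eq!(instr_char_pos("hello world", "world"), 7);
    // Multi-byte UTF-8: "日本語" is 9 bytes but 3 characters, so a
    // byte-based result would be 10 while the character position is 4.
    assert_eq!(instr_char_pos("日本語abc", "abc"), 4);
}
```

Slicing with `s[..byte_pos]` is safe here because `find` always returns an offset on a character boundary.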
```scala
assert(result(1) == 1, "instr('hello world', 'hello') should return 1")
assert(result(2) == 5, "instr('hello world', 'o') should return 5")
assert(result(3) == 0, "instr('hello world', 'z') should return 0")
assert(result(4) == 0, "instr(null, 'test') should return null")
```
Why does the assertion message say this should return null while the code asserts 0?
```scala
assert(result(2) == 5, "instr('hello world', 'o') should return 5")
assert(result(3) == 0, "instr('hello world', 'z') should return 0")
assert(result(4) == 0, "instr(null, 'test') should return null")
assert(result(5) == 0, "instr('test', null) should return null")
```
Pull request overview
Implements native-engine support for Spark SQL’s instr expression by wiring Spark’s StringInstr catalyst expression through Auron’s native converter layer into a new DataFusion extension function, and adds coverage in both Rust and Spark-side test suites.
Changes:
- Add Spark expression conversion for `StringInstr` to call ext function `Spark_Instr`.
- Implement `Spark_Instr` as a DataFusion extension function (`spark_instr`) and register it in the native ext-function factory.
- Add a Spark 3.3 test suite to validate `instr` behavior end-to-end.
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 6 comments.
| File | Description |
|---|---|
| spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala | Routes Spark StringInstr to the native ext function Spark_Instr. |
| native-engine/datafusion-ext-functions/src/spark_instr.rs | Adds the Rust implementation of spark_instr plus unit tests. |
| native-engine/datafusion-ext-functions/src/lib.rs | Registers the new Spark_Instr function and includes the module. |
| auron-spark-tests/spark33/src/test/scala/org/apache/spark/sql/AuronInstrSuite.scala | Adds Spark-side tests for instr usage in projections/filters/etc. |
```scala
val result = df.selectExpr("instr(str, substr)").collect().map(_.getInt(0))

assert(result(0) == 7, "instr('hello world', 'world') should return 7")
assert(result(1) == 1, "instr('hello world', 'hello') should return 1")
assert(result(2) == 5, "instr('hello world', 'o') should return 5")
assert(result(3) == 0, "instr('hello world', 'z') should return 0")
assert(result(4) == 0, "instr(null, 'test') should return null")
assert(result(5) == 0, "instr('test', null) should return null")
```
The test collects instr(str, substr) results via row.getInt(0) and then asserts 0 for null-input cases, but instr(null, ...) / instr(..., null) should produce a NULL result. With the native implementation returning NULL for null inputs, getInt(0) will throw at runtime. Collect as Integer/java.lang.Integer (or check row.isNullAt(0)) and assert NULLs for these rows instead of 0.
Suggested change:
```diff
-val result = df.selectExpr("instr(str, substr)").collect().map(_.getInt(0))
-assert(result(0) == 7, "instr('hello world', 'world') should return 7")
-assert(result(1) == 1, "instr('hello world', 'hello') should return 1")
-assert(result(2) == 5, "instr('hello world', 'o') should return 5")
-assert(result(3) == 0, "instr('hello world', 'z') should return 0")
-assert(result(4) == 0, "instr(null, 'test') should return null")
-assert(result(5) == 0, "instr('test', null) should return null")
+val result = df.selectExpr("instr(str, substr)").collect()
+assert(result(0).getInt(0) == 7, "instr('hello world', 'world') should return 7")
+assert(result(1).getInt(0) == 1, "instr('hello world', 'hello') should return 1")
+assert(result(2).getInt(0) == 5, "instr('hello world', 'o') should return 5")
+assert(result(3).getInt(0) == 0, "instr('hello world', 'z') should return 0")
+assert(result(4).isNullAt(0), "instr(null, 'test') should return null")
+assert(result(5).isNullAt(0), "instr('test', null) should return null")
```
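The null semantics at issue can be sketched in plain Rust (illustrative only; `instr_nullable` is not from the PR, and the real kernel would use Arrow validity bitmaps rather than `Option`): a NULL on either input yields NULL, never 0.

```rust
// Hedged sketch of Spark's null propagation for instr: NULL in -> NULL out.
fn instr_nullable(s: Option<&str>, substr: Option<&str>) -> Option<i32> {
    let (s, substr) = (s?, substr?); // a NULL input short-circuits to None
    if substr.is_empty() {
        return Some(0); // mirrors the PR's empty-substring branch
    }
    Some(s.find(substr).map(|pos| pos as i32 + 1).unwrap_or(0))
}

fn main() {
    assert_eq!(instr_nullable(Some("hello world"), Some("world")), Some(7));
    // "not found" is 0, which is distinct from NULL:
    assert_eq!(instr_nullable(Some("hello world"), Some("z")), Some(0));
    assert_eq!(instr_nullable(None, Some("test")), None); // NULL string
    assert_eq!(instr_nullable(Some("test"), None), None); // NULL substring
}
```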
```scala
val result = df
  .groupBy("substr")
  .count()
  .filter("count > 0")
  .orderBy("substr")
  .collect()

assert(result.length >= 1)
```
This test is titled "in group by" but it never evaluates instr (it only groups by substr and counts). As written, it doesn't validate that instr works correctly in the presence of aggregations / grouping. Consider incorporating instr(...) either in the grouping key, an aggregate expression, or a post-aggregation filter so the test actually exercises the new function in a group-by context.
Suggested change:
```diff
-val result = df
-  .groupBy("substr")
-  .count()
-  .filter("count > 0")
-  .orderBy("substr")
-  .collect()
-assert(result.length >= 1)
+df.createOrReplaceTempView("instr_group_test")
+val result = spark.sql(
+  """
+    |SELECT
+    |  substr,
+    |  SUM(CASE WHEN instr(str, substr) > 0 THEN 1 ELSE 0 END) AS match_count
+    |FROM instr_group_test
+    |GROUP BY substr
+    |HAVING match_count > 0
+    |ORDER BY substr
+    |""".stripMargin)
+  .collect()
+// For the input data, only "test" appears within "str" values, and it occurs 3 times.
+assert(result.length == 1)
+assert(result(0).getString(0) == "test")
+assert(result(0).getLong(1) == 3L)
```
```rust
if substr.is_empty() {
    Some(0)
} else {
    Some(s.find(substr).map(|pos| (pos + 1) as i32).unwrap_or(0))
```
(pos + 1) as i32 can overflow if the match position exceeds i32::MAX - 1, producing an incorrect (potentially negative) result. Since Spark's return type is Int32, use a checked/try conversion and decide on an explicit behavior (e.g., error or clamp to i32::MAX) rather than a lossy cast.
Suggested change:
```diff
-Some(s.find(substr).map(|pos| (pos + 1) as i32).unwrap_or(0))
+Some(
+    s.find(substr)
+        .map(|pos| {
+            let one_based = pos.saturating_add(1);
+            if one_based > i32::MAX as usize {
+                i32::MAX
+            } else {
+                one_based as i32
+            }
+        })
+        .unwrap_or(0),
+)
```
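The comment above also mentions a checked/try conversion as an alternative to clamping. A minimal sketch of that option (illustrative only; `one_based_pos` and the error message are not from the PR) surfaces an explicit error instead of silently saturating:

```rust
// Hedged sketch: convert a usize match position to a 1-based Int32 without a
// lossy cast, reporting an error when the value does not fit.
fn one_based_pos(pos: usize) -> Result<i32, String> {
    // saturating_add guards the (theoretical) usize overflow of pos + 1;
    // try_from then rejects anything above i32::MAX.
    i32::try_from(pos.saturating_add(1))
        .map_err(|_| format!("instr match position {pos} does not fit in Int32"))
}

fn main() {
    assert_eq!(one_based_pos(6), Ok(7)); // e.g. "world" in "hello world"
    assert!(one_based_pos(usize::MAX).is_err()); // out of Int32 range
}
```

Whether to error or clamp is a semantics decision; Spark itself cannot represent such positions in an Int32 result, so either choice should be documented.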
```rust
let scalar = as_int32_array(&result_array)?.value(0);
Ok(ColumnarValue::Scalar(if result_array.is_null(0) {
    ScalarValue::Int32(None)
} else {
    ScalarValue::Int32(Some(scalar))
}))
```
In the scalar-return branch, value(0) is read before checking is_null(0). Even if this currently works because Arrow stores a buffer value for nulls, it's easy to misread and can become unsafe if assumptions change. Consider checking is_null(0) first and only reading value(0) in the non-null case.
Suggested change:
```diff
-let scalar = as_int32_array(&result_array)?.value(0);
-Ok(ColumnarValue::Scalar(if result_array.is_null(0) {
-    ScalarValue::Int32(None)
-} else {
-    ScalarValue::Int32(Some(scalar))
-}))
+let int_array = as_int32_array(&result_array)?;
+let scalar = if int_array.is_null(0) {
+    None
+} else {
+    Some(int_array.value(0))
+};
+Ok(ColumnarValue::Scalar(ScalarValue::Int32(scalar)))
```
```rust
use arrow::array::{Array, ArrayRef, Int32Array, StringArray};
use datafusion::{
    common::{
        Result, ScalarValue,
        cast::{as_int32_array, as_string_array},
    },
    physical_plan::ColumnarValue,
};
```
This module imports StringArray but doesn't use it in the production code (tests have their own imports). With CI running cargo clippy ... -D warnings, this unused import will fail the build. Remove the unused StringArray import (or use it explicitly if intended).
```rust
use arrow::array::{ArrayRef, Int32Array, StringArray};
use datafusion::{
    common::{Result, ScalarValue, cast::as_int32_array},
    physical_plan::ColumnarValue,
};
```
The test module imports ArrayRef and Int32Array but doesn't use them. Since the repo runs clippy with -D warnings, these unused imports will fail CI. Remove the unused imports from the test module.
Which issue does this PR close?
Closes #2067
Rationale for this change
What changes are included in this PR?
Are there any user-facing changes?
How was this patch tested?