diff --git a/docs/streaming/structured-streaming-transform-with-state.md b/docs/streaming/structured-streaming-transform-with-state.md index 050337f2dd777..c74c3a0909c31 100644 --- a/docs/streaming/structured-streaming-transform-with-state.md +++ b/docs/streaming/structured-streaming-transform-with-state.md @@ -397,3 +397,347 @@ We also allow for composite type variables to be read in 2 formats: Depending on your memory requirements, you can choose the format that best suits your use case. More information about source options can be found [here](./structured-streaming-state-data-source.html). +## Testing StatefulProcessor with TwsTester + +It may be useful to test your implementation of `StatefulProcessor` without having to run a +streaming query. For this, use a test helper called TwsTester ( +[org.apache.spark.sql.streaming.TwsTester](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala) +for Scala/Java or +[pyspark.sql.streaming.tws_tester.TwsTester](https://github.com/apache/spark/blob/master/python/pyspark/sql/streaming/tws_tester.py) +for Python). + +TwsTester is a wrapper around your StatefulProcessor that allows you to supply input data and +observe what rows would be produced by the TransformWithState operator. + +Additionally, TwsTester supports the following functionality: +* Passing initial state - to test your implementation of `handleInitialState`. +* Observing and updating state. +* Manually advancing time to test timers (in ProcessingTime and EventTime modes). This can be used + to test your implementation of `handleExpiredTimer`. + +### Limitations of TwsTester + +TwsTester is designed to test your implementation of StatefulProcessor rather than Spark's +implementation of TransformWithState. It's recommended for prototyping or writing unit +tests, but not for end-to-end testing. 
These are features of real TransformWithState that TwsTester +does not model: +* TTL (Time-To-Live, i.e. eviction of expired state). +* Automatic watermark advancement. In EventTime mode, the watermark is computed for each batch based + on the event time column. It is used to expire timers and is passed to the StatefulProcessor. + TwsTester does not do this, so you must advance the watermark manually with `setWatermark`. + + +### TwsTester usage examples + +The code below shows examples of using TwsTester in Scala and Python. It uses some example stateful +processors. You can find these processors +[here](https://github.com/apache/spark/blob/master/sql/core/src/test/scala/org/apache/spark/sql/streaming/TestProcessors.scala) +for Scala and +[here](https://github.com/apache/spark/blob/master/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py) +for Python. + + +
+
+ +{% highlight python %} + +from pyspark.sql.streaming import TwsTester +from pyspark.sql.types import Row +import pandas as pd + +# Example 1: Basic usage with Row mode - testing input and output. +processor = RunningCountStatefulProcessor(use_pandas=False) +tester = TwsTester(processor) +result = tester.test("key1", [Row(value="a"), Row(value="b")]) +assert result == [Row(key="key1", count=2)] + +# Example 2: Basic usage with Pandas mode - testing input and output. +processor = RunningCountStatefulProcessor(use_pandas=True) +tester = TwsTester(processor) +input_df = pd.DataFrame({"value": ["a", "b"]}) +result_df = tester.testInPandas("key1", input_df) +assert result_df["key"].tolist() == ["key1"] +assert result_df["count"].tolist() == [2] + +# Example 3: Testing with initial state (Row mode). +processor = RunningCountStatefulProcessor(use_pandas=False) +tester = TwsTester(processor, initialStateRow=[ + ("a", Row(initial_count=10)), + ("b", Row(initial_count=20)) +]) +result = tester.test("a", [Row(value="x")]) +assert result == [Row(key="a", count=11)] + +# Example 4: Testing with initial state (Pandas mode). +processor = RunningCountStatefulProcessor(use_pandas=True) +tester = TwsTester(processor, initialStatePandas=[ + ("a", pd.DataFrame({"initial_count": [10]})), + ("b", pd.DataFrame({"initial_count": [20]})) +]) +result_df = tester.testInPandas("a", pd.DataFrame({"value": ["x"]})) +assert result_df.set_index("key")["count"]["a"] == 11 + +# Example 5: Inspecting ValueState. +processor = RunningCountStatefulProcessor(use_pandas=False) +tester = TwsTester(processor) +tester.test("key1", [Row(value="a"), Row(value="b")]) +assert tester.peekValueState("count", "key1") == (2,) +assert tester.peekValueState("count", "key3") is None + +# Example 6: Manipulating ValueState directly. 
+processor = RunningCountStatefulProcessor(use_pandas=False) +tester = TwsTester(processor) +tester.updateValueState("count", "foo", (100,)) +tester.test("foo", [Row(value="a")]) +assert tester.peekValueState("count", "foo") == (101,) + +# Example 7: Inspecting and manipulating ListState. +processor = TopKProcessor(k=3, use_pandas=False) +tester = TwsTester(processor) +tester.updateListState("topK", "key1", [(10.0,), (5.0,)]) +tester.test("key1", [Row(score=7.0)]) +assert tester.peekListState("topK", "key1") == [(10.0,), (7.0,), (5.0,)] + +# Example 8: Inspecting and manipulating MapState. +processor = RowWordFrequencyProcessor() +tester = TwsTester(processor) +tester.updateMapState("frequencies", "user1", {("hello",): (5,), ("world",): (3,)}) +tester.test("user1", [Row(word="hello"), Row(word="spark")]) +state = tester.peekMapState("frequencies", "user1") +assert state[("hello",)] == (6,) +assert state[("spark",)] == (1,) + +{% endhighlight %} +
+ + +
+ +{% highlight scala %} + +import org.apache.spark.sql.streaming.{TwsTester, TimeMode} + +// Example 1: Basic usage - testing input and output. +val processor = new RunningCountProcessor[String]() +val tester = new TwsTester[String, String, (String, Long)](processor) +val ans = tester.test("key1", List("a", "b")) +assert(ans == List(("key1", 2))) + +// Example 2: Testing with initial state. +val processor = new RunningCountProcessor[String]() +val tester = new TwsTester(processor, initialState = List(("a", 10L), ("b", 20L))) +val ans = tester.test("a", List("x")) +assert(ans == List(("a", 11L))) + +// Example 3: Inspecting ValueState. +val tester = new TwsTester(new RunningCountProcessor[String]()) +tester.test("key1", List("a")) +assert(tester.peekValueState[Long]("count", "key1").get == 1L) +assert(tester.peekValueState[Long]("count", "key2").isEmpty) + +// Example 4: Manipulating ValueState directly. +val tester = new TwsTester(new RunningCountProcessor[String]()) +tester.updateValueState[Long]("count", "foo", 100) +tester.test("foo", List("a")) +assert(tester.peekValueState[Long]("count", "foo").get == 101L) + +// Example 5: Inspecting and manipulating ListState. +val tester = new TwsTester(new TopKProcessor(3)) +tester.updateListState("topK", "key1", List(10.0, 5.0)) +val ans = tester.test("key1", List(("item", 7.0))) +assert(ans == List(("key1", 10.0), ("key1", 7.0), ("key1", 5.0))) +assert(tester.peekListState[Double]("topK", "key1") == List(10.0, 7.0, 5.0)) + +// Example 6: Inspecting and manipulating MapState. +val tester = new TwsTester(new WordFrequencyProcessor()) +tester.updateMapState("frequencies", "user1", Map("hello" -> 5L, "world" -> 3L)) +tester.test("user1", List(("", "hello"), ("", "spark"))) +val state = tester.peekMapState[String, Long]("frequencies", "user1") +assert(state == Map("hello" -> 6L, "world" -> 3L, "spark" -> 1L)) + +{% endhighlight %} +
+
+ + + +### TwsTester example: step function + +If you want to test how your StatefulProcessor updates state after processing a single row, you +might write a simple helper for this: + +
+
+ +{% highlight python %} + +processor = RunningCountStatefulProcessor(use_pandas=False) +tester = TwsTester(processor) + +def step_function(key: str, input_row: str, state_in: int) -> int: + tester.updateValueState("count", key, (state_in,)) + tester.test(key, [Row(value=input_row)]) + return tester.peekValueState("count", key)[0] + +assert step_function("key1", "a", 10) == 11 + +{% endhighlight %} +
+ + +
+{% highlight scala %} +val tester = new TwsTester(new RunningCountProcessor[String]()) + +def testStepFunction(key: String, input: String, stateIn: Long): Long = { + tester.updateValueState[Long]("count", key, stateIn) + tester.test(key, List(input)) + tester.peekValueState[Long]("count", key).get +} + +assert(testStepFunction("k", "x", 10L) == 11L) +{% endhighlight %} +
+
+ + + +### TwsTester example: row-by-row processing + +`TwsTester.test` takes multiple rows for a single key to match the API of `handleInputRows`. +If you have a list of rows and you want them to be processed one by one (so that `handleInputRows` +always receives exactly one row), you can do that with a helper like this: + + +
+
+ +{% highlight python %} + +processor = RunningCountStatefulProcessor(use_pandas=False) +tester = TwsTester(processor) +def test_row_by_row(input_rows): + return [output for row in input_rows for output in tester.test(row["key"], [row])] +output = test_row_by_row([Row(key="k1", value="a"), Row(key="k2", value="b"), Row(key="k1", value="c")]) +assert output == [Row(key="k1", count=1), Row(key="k2", count=1), Row(key="k1", count=2)] + +{% endhighlight %} +
+ + +
+ +{% highlight scala %} + +val tester = new TwsTester(new RunningCountProcessor[String]()) +def testRowByRow(input: List[(String, String)]): List[(String, Long)] = { + input.flatMap(row => tester.test(row._1, List(row._2))) +} +val output = testRowByRow(List(("k1", "a"), ("k2", "b"), ("k1", "c"))) +assert(output == List(("k1", 1L), ("k2", 1L), ("k1", 2L))) + +{% endhighlight %} +
+
+ + +### TwsTester example: batch simulation + +If you have a list of rows with different keys and you want to simulate how they will be grouped by +key and then passed to your StatefulProcessor, you can write a helper like this: + +
+
+ +{% highlight python %} + +from itertools import groupby +processor = RunningCountStatefulProcessor(use_pandas=False) +tester = TwsTester(processor) + +def testBatch(input: list[Row], key_column_name: str = "key") -> list[Row]: + result: list[Row] = [] + sorted_input = sorted(input, key=lambda row: row[key_column_name]) + for key, rows in groupby(sorted_input, key=lambda row: row[key_column_name]): + result += tester.test(key, list(rows)) + return result + +batch1 = testBatch([Row(key="key1", value="a"), Row(key="key2", value="b"), Row(key="key1", value="c")]) +assert batch1 == [Row(key="key1", count=2), Row(key="key2", count=1)] +batch2 = testBatch([Row(key="key1", value="c"), Row(key="key1", value="d")]) +assert batch2 == [Row(key="key1", count=4)] + +{% endhighlight %} +
+ + +
+ +{% highlight scala %} + +val tester = new TwsTester(new RunningCountProcessor[String]()) + +def testBatch(input: List[(String, String)]): List[(String, Long)] = + input.groupBy(_._1).toList.flatMap { case (key, pairs) => tester.test(key, pairs.map(_._2)) } + +val batch1 = testBatch(List(("key1", "a"), ("key2", "b"))) +assert(batch1.sorted == List(("key1", 1L), ("key2", 1L))) +val batch2 = testBatch(List(("key1", "c"), ("key1", "d"))) +assert(batch2 == List(("key1", 3L))) + +{% endhighlight %} +
+
+ + +### TwsTester example: timers + +To test timers, create the TwsTester with `timeMode` set to `ProcessingTime` or `EventTime`. Use +`setProcessingTime` or `setWatermark` to advance time and fire expired timers. + +
+
+ +{% highlight python %} + +processor = SessionTimeoutProcessor(use_pandas=False) +tester = TwsTester(processor, timeMode="ProcessingTime") + +# Process input - registers a timer at t=10000. +result1 = tester.test("key1", [Row(value="hello")]) +assert result1 == [Row(key="key1", result="received:hello")] + +# Advance time to 5000 - timer should NOT fire yet. +expired1 = tester.setProcessingTime(5000) +assert expired1 == [] + +# Advance time to 11000 - timer fires, handleExpiredTimer is called. +expired2 = tester.setProcessingTime(11000) +assert expired2 == [Row(key="key1", result="session-expired")] + +{% endhighlight %} +
+ + +
+{% highlight scala %} +val tester = new TwsTester( + new SessionTimeoutProcessor(), + timeMode = TimeMode.ProcessingTime() +) + +// Process input - registers a timer at t=10000. +val result1 = tester.test("key1", List("hello")) +assert(result1 == List(("key1", "received:hello"))) + +// Advance time to 5000 - timer should NOT fire yet. +val expired1 = tester.setProcessingTime(5000) +assert(expired1.isEmpty) + +// Advance time to 11000 - timer fires, handleExpiredTimer is called. +val expired2 = tester.setProcessingTime(11000) +assert(expired2 == List(("key1", "session-expired@10000"))) +{% endhighlight %} +
+