344 changes: 344 additions & 0 deletions docs/streaming/structured-streaming-transform-with-state.md
@@ -397,3 +397,347 @@ We also allow for composite type variables to be read in 2 formats:
Depending on your memory requirements, you can choose the format that best suits your use case.
More information about source options can be found [here](./structured-streaming-state-data-source.html).

## Testing StatefulProcessor with TwsTester

You can test your implementation of `StatefulProcessor` without running a streaming query.
To do so, use the TwsTester test helper (
[org.apache.spark.sql.streaming.TwsTester](https://github.com/apache/spark/blob/master/sql/core/src/main/scala/org/apache/spark/sql/streaming/TwsTester.scala)
for Scala/Java or
[pyspark.sql.streaming.tws_tester.TwsTester](https://github.com/apache/spark/blob/master/python/pyspark/sql/streaming/tws_tester.py)
for Python).

TwsTester is a wrapper around your StatefulProcessor that allows you to supply input data and
observe what rows would be produced by the TransformWithState operator.

Additionally, TwsTester supports the following functionality:
* Passing initial state - to test your implementation of `handleInitialState`.
* Observing and updating state.
* Manually advancing time to test timers (in ProcessingTime and EventTime modes). This can be used
to test your implementation of `handleExpiredTimer`.

### Limitations of TwsTester

TwsTester is designed to test your implementation of StatefulProcessor rather than Spark's
implementation of TransformWithState. It is recommended for prototyping and unit tests, but not
for end-to-end testing. TwsTester does not model the following features of the real
TransformWithState operator:
* TTL (time-to-live), which evicts expired state.
* Automatic watermark advancement. In EventTime mode, Spark computes the watermark for each batch
from the event time column, uses it to expire timers, and passes it to the StatefulProcessor.
TwsTester does not do this, so you must advance the watermark manually.
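
Since the watermark has to be advanced by hand, a test may want to derive a plausible value from its own input rows. The sketch below is one way to do that in plain Python. It assumes the common rule of "largest event time seen minus the allowed delay", and it never moves the watermark backwards. The function name `next_watermark` and its parameters are hypothetical and are not part of TwsTester.

```python
from typing import Iterable


def next_watermark(event_times_ms: Iterable[int], delay_ms: int, current_ms: int = 0) -> int:
    """Return a watermark for a simulated batch (hypothetical helper).

    Assumed rule: watermark = max(event time in the batch) - delay,
    and the watermark never decreases.
    """
    latest = max(event_times_ms, default=None)
    if latest is None:
        # An empty batch leaves the watermark where it was.
        return current_ms
    return max(current_ms, latest - delay_ms)


# First simulated batch: the latest event is at 15 s, the allowed delay is 5 s.
wm1 = next_watermark([12_000, 15_000, 9_000], delay_ms=5_000)
# Second batch only contains older events, so the watermark stays put.
wm2 = next_watermark([11_000], delay_ms=5_000, current_ms=wm1)
```

The computed value is what a test would then hand to `setWatermark` on a tester created in EventTime mode.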


### TwsTester usage examples

The code below shows how to use TwsTester in Scala and Python. The examples rely on sample
stateful processors, which you can find
[here](https://github.com/apache/spark/blob/master/sql/core/src/test/scala/org/apache/spark/sql/streaming/TestProcessors.scala)
for Scala and
[here](https://github.com/apache/spark/blob/master/python/pyspark/sql/tests/pandas/helper/helper_pandas_transform_with_state.py)
for Python.


<div class="codetabs">
<div data-lang="python" markdown="1">

{% highlight python %}

from pyspark.sql.streaming import TwsTester
from pyspark.sql.types import Row
import pandas as pd

# Example 1: Basic usage with Row mode - testing input and output.
processor = RunningCountStatefulProcessor(use_pandas=False)
tester = TwsTester(processor)
result = tester.test("key1", [Row(value="a"), Row(value="b")])
assert result == [Row(key="key1", count=2)]

# Example 2: Basic usage with Pandas mode - testing input and output.
processor = RunningCountStatefulProcessor(use_pandas=True)
tester = TwsTester(processor)
input_df = pd.DataFrame({"value": ["a", "b"]})
result_df = tester.testInPandas("key1", input_df)
assert result_df["key"].tolist() == ["key1"]
assert result_df["count"].tolist() == [2]

# Example 3: Testing with initial state (Row mode).
processor = RunningCountStatefulProcessor(use_pandas=False)
tester = TwsTester(processor, initialStateRow=[
    ("a", Row(initial_count=10)),
    ("b", Row(initial_count=20))
])
result = tester.test("a", [Row(value="x")])
assert result == [Row(key="a", count=11)]

# Example 4: Testing with initial state (Pandas mode).
processor = RunningCountStatefulProcessor(use_pandas=True)
tester = TwsTester(processor, initialStatePandas=[
    ("a", pd.DataFrame({"initial_count": [10]})),
    ("b", pd.DataFrame({"initial_count": [20]}))
])
result_df = tester.testInPandas("a", pd.DataFrame({"value": ["x"]}))
assert result_df.set_index("key")["count"]["a"] == 11

# Example 5: Inspecting ValueState.
processor = RunningCountStatefulProcessor(use_pandas=False)
tester = TwsTester(processor)
tester.test("key1", [Row(value="a"), Row(value="b")])
assert tester.peekValueState("count", "key1") == (2,)
assert tester.peekValueState("count", "key3") is None

# Example 6: Manipulating ValueState directly.
processor = RunningCountStatefulProcessor(use_pandas=False)
tester = TwsTester(processor)
tester.updateValueState("count", "foo", (100,))
tester.test("foo", [Row(value="a")])
assert tester.peekValueState("count", "foo") == (101,)

# Example 7: Inspecting and manipulating ListState.
processor = TopKProcessor(k=3, use_pandas=False)
tester = TwsTester(processor)
tester.updateListState("topK", "key1", [(10.0,), (5.0,)])
tester.test("key1", [Row(score=7.0)])
assert tester.peekListState("topK", "key1") == [(10.0,), (7.0,), (5.0,)]

# Example 8: Inspecting and manipulating MapState.
processor = RowWordFrequencyProcessor()
tester = TwsTester(processor)
tester.updateMapState("frequencies", "user1", {("hello",): (5,), ("world",): (3,)})
tester.test("user1", [Row(word="hello"), Row(word="spark")])
state = tester.peekMapState("frequencies", "user1")
assert state[("hello",)] == (6,)
assert state[("spark",)] == (1,)

{% endhighlight %}
</div>


<div data-lang="scala" markdown="1">

{% highlight scala %}

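import org.apache.spark.sql.streaming.{TwsTester, TimeMode}

// Each example below is independent and reuses names such as `tester`;
// run each one in its own scope (for example, a separate test case).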

// Example 1: Basic usage - testing input and output.
val processor = new RunningCountProcessor[String]()
val tester = new TwsTester[String, String, (String, Long)](processor)
val ans = tester.test("key1", List("a", "b"))
assert(ans == List(("key1", 2L)))

// Example 2: Testing with initial state.
val processor = new RunningCountProcessor[String]()
val tester = new TwsTester(processor, initialState = List(("a", 10L), ("b", 20L)))
val ans = tester.test("a", List("x"))
assert(ans == List(("a", 11L)))

// Example 3: Inspecting ValueState.
val tester = new TwsTester(new RunningCountProcessor[String]())
tester.test("key1", List("a"))
assert(tester.peekValueState[Long]("count", "key1").get == 1L)
assert(tester.peekValueState[Long]("count", "key2").isEmpty)

// Example 4: Manipulating ValueState directly.
val tester = new TwsTester(new RunningCountProcessor[String]())
tester.updateValueState[Long]("count", "foo", 100)
tester.test("foo", List("a"))
assert(tester.peekValueState[Long]("count", "foo").get == 101L)

// Example 5: Inspecting and manipulating ListState.
val tester = new TwsTester(new TopKProcessor(3))
tester.updateListState("topK", "key1", List(10.0, 5.0))
val ans = tester.test("key1", List(("item", 7.0)))
assert(ans == List(("key1", 10.0), ("key1", 7.0), ("key1", 5.0)))
assert(tester.peekListState[Double]("topK", "key1") == List(10.0, 7.0, 5.0))

// Example 6: Inspecting and manipulating MapState.
val tester = new TwsTester(new WordFrequencyProcessor())
tester.updateMapState("frequencies", "user1", Map("hello" -> 5L, "world" -> 3L))
tester.test("user1", List(("", "hello"), ("", "spark")))
val state = tester.peekMapState[String, Long]("frequencies", "user1")
assert(state == Map("hello" -> 6L, "world" -> 3L, "spark" -> 1L))

{% endhighlight %}
</div>
</div>



### TwsTester example: step function

To test how your StatefulProcessor updates state after processing a single row, you can write a
small helper like this:

<div class="codetabs">
<div data-lang="python" markdown="1">

{% highlight python %}

processor = RunningCountStatefulProcessor(use_pandas=False)
tester = TwsTester(processor)

def step_function(key: str, input_row: str, state_in: int) -> int:
    tester.updateValueState("count", key, (state_in,))
    tester.test(key, [Row(value=input_row)])
    return tester.peekValueState("count", key)[0]

assert step_function("key1", "a", 10) == 11

{% endhighlight %}
</div>


<div data-lang="scala" markdown="1">
{% highlight scala %}
val tester = new TwsTester(new RunningCountProcessor[String]())

def testStepFunction(key: String, input: String, stateIn: Long): Long = {
  tester.updateValueState[Long]("count", key, stateIn)
  tester.test(key, List(input))
  tester.peekValueState[Long]("count", key).get
}

assert(testStepFunction("k", "x", 10L) == 11L)
{% endhighlight %}
</div>
</div>



### TwsTester example: row-by-row processing

`TwsTester.test` takes multiple rows for a single key to match the API of `handleInputRows`.
If you have a list of rows and you want them to be processed one by one (so that `handleInputRows`
always receives exactly one row), you can do that with a helper like this:


<div class="codetabs">
<div data-lang="python" markdown="1">

{% highlight python %}

processor = RunningCountStatefulProcessor(use_pandas=False)
tester = TwsTester(processor)

def test_row_by_row(input_rows):
    return [output for row in input_rows for output in tester.test(row["key"], [row])]

output = test_row_by_row([Row(key="k1", value="a"), Row(key="k2", value="b"), Row(key="k1", value="c")])
assert output == [Row(key="k1", count=1), Row(key="k2", count=1), Row(key="k1", count=2)]

{% endhighlight %}
</div>


<div data-lang="scala" markdown="1">

{% highlight scala %}

val tester = new TwsTester(new RunningCountProcessor[String]())
def testRowByRow(input: List[(String, String)]): List[(String, Long)] = {
  input.flatMap(row => tester.test(row._1, List(row._2)))
}
val output = testRowByRow(List(("k1", "a"), ("k2", "b"), ("k1", "c")))
assert(output == List(("k1", 1L), ("k2", 1L), ("k1", 2L)))

{% endhighlight %}
</div>
</div>


### TwsTester example: batch simulation

If you have a list of rows with different keys and you want to simulate how they will be grouped by
key and then passed to your StatefulProcessor, you can write a helper for that like this:

<div class="codetabs">
<div data-lang="python" markdown="1">

{% highlight python %}

from itertools import groupby
processor = RunningCountStatefulProcessor(use_pandas=False)
tester = TwsTester(processor)

def testBatch(input: list[Row], key_column_name: str = "key") -> list[Row]:
    result: list[Row] = []
    # groupby only merges adjacent rows, so sort by key first.
    sorted_input = sorted(input, key=lambda row: row[key_column_name])
    for key, rows in groupby(sorted_input, key=lambda row: row[key_column_name]):
        result += tester.test(key, list(rows))
    return result

batch1 = testBatch([Row(key="key1", value="a"), Row(key="key2", value="b"), Row(key="key1", value="c")])
assert batch1 == [Row(key="key1", count=2), Row(key="key2", count=1)]
batch2 = testBatch([Row(key="key1", value="c"), Row(key="key1", value="d")])
assert batch2 == [Row(key="key1", count=4)]

{% endhighlight %}
</div>


<div data-lang="scala" markdown="1">

{% highlight scala %}

val tester = new TwsTester(new RunningCountProcessor[String]())

def testBatch(input: List[(String, String)]): List[(String, Long)] =
  input.groupBy(_._1).toList.flatMap { case (key, pairs) => tester.test(key, pairs.map(_._2)) }

val batch1 = testBatch(List(("key1", "a"), ("key2", "b")))
assert(batch1.sorted == List(("key1", 1L), ("key2", 1L)))
val batch2 = testBatch(List(("key1", "c"), ("key1", "d")))
assert(batch2 == List(("key1", 3L)))

{% endhighlight %}
</div>
</div>
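
The Python helper above sorts rows by key because `itertools.groupby` only merges adjacent items, so keys come out in sorted order. If a test would rather see keys in the order they first appear in the batch, the grouping step can be done with a plain dictionary instead. The sketch below shows only that grouping step and runs without Spark. The name `group_by_key` and the use of tuples as rows are illustrative assumptions, not part of TwsTester.

```python
from typing import Any, Callable, Hashable


def group_by_key(
    rows: list[Any], key_of: Callable[[Any], Hashable]
) -> list[tuple[Hashable, list[Any]]]:
    """Group rows by key, keeping keys in first-seen order (hypothetical helper).

    Rows within each group keep their original relative order.
    """
    groups: dict[Hashable, list[Any]] = {}
    for row in rows:
        # dict preserves insertion order, so the first row of a key fixes its position.
        groups.setdefault(key_of(row), []).append(row)
    return list(groups.items())


# Rows are modelled as (key, value) tuples for this sketch.
batch = [("key2", "b"), ("key1", "a"), ("key2", "c")]
grouped = group_by_key(batch, key_of=lambda row: row[0])
```

Each `(key, rows)` pair in `grouped` corresponds to one call to `tester.test(key, rows)` in a batch helper. A real streaming query gives no ordering guarantee across keys, so assertions that depend on key order are best kept to helper-level tests like this one.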


### TwsTester example: timers

To test timers, create the TwsTester with `timeMode` set to `ProcessingTime` or `EventTime`. Use
`setProcessingTime` or `setWatermark` to advance time and fire expired timers.

<div class="codetabs">
<div data-lang="python" markdown="1">

{% highlight python %}

processor = SessionTimeoutProcessor(use_pandas=False)
tester = TwsTester(processor, timeMode="ProcessingTime")

# Process input - registers a timer at t=10000.
result1 = tester.test("key1", [Row(value="hello")])
assert result1 == [Row(key="key1", result="received:hello")]

# Advance time to 5000 - timer should NOT fire yet.
expired1 = tester.setProcessingTime(5000)
assert expired1 == []

# Advance time to 11000 - timer fires, handleExpiredTimer is called.
expired2 = tester.setProcessingTime(11000)
assert expired2 == [Row(key="key1", result="session-expired")]

{% endhighlight %}
</div>


<div data-lang="scala" markdown="1">
{% highlight scala %}
val tester = new TwsTester(
  new SessionTimeoutProcessor(),
  timeMode = TimeMode.ProcessingTime()
)

// Process input - registers a timer at t=10000.
val result1 = tester.test("key1", List("hello"))
assert(result1 == List(("key1", "received:hello")))

// Advance time to 5000 - timer should NOT fire yet.
val expired1 = tester.setProcessingTime(5000)
assert(expired1.isEmpty)

// Advance time to 11000 - timer fires, handleExpiredTimer is called.
val expired2 = tester.setProcessingTime(11000)
assert(expired2 == List(("key1", "session-expired@10000")))
{% endhighlight %}
</div>
</div>
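
When a test advances time through several points, it can be convenient to record what fired at each step. The sketch below wraps repeated `setProcessingTime` calls in a small helper. So that it runs without Spark, it uses `FakeTimerTester`, a hypothetical stand-in that holds a single timer and fires it once the clock reaches or passes the expiry time. That stand-in and its firing rule are assumptions made for illustration and do not describe TwsTester internals.

```python
class FakeTimerTester:
    """Hypothetical stand-in for a tester with one registered timer (not TwsTester)."""

    def __init__(self, key: str, expiry_ms: int) -> None:
        self._pending = {key: expiry_ms}

    def setProcessingTime(self, now_ms: int) -> list[tuple[str, str]]:
        # Assumed rule: a timer fires once, when the clock reaches or passes its expiry.
        fired = [key for key, expiry in self._pending.items() if expiry <= now_ms]
        for key in fired:
            del self._pending[key]
        return [(key, "session-expired") for key in fired]


def advance_in_steps(tester, timestamps_ms: list[int]) -> list[tuple[int, list]]:
    """Advance processing time step by step and record the rows emitted at each step."""
    return [(ts, tester.setProcessingTime(ts)) for ts in timestamps_ms]


# One timer at t=10000, as in the example above.
timeline = advance_in_steps(FakeTimerTester("key1", 10_000), [5_000, 11_000, 20_000])
```

With a real tester, the same `advance_in_steps` helper would be passed the TwsTester instance in place of the stand-in.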