Skip to content
Closed
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions python/pyspark/sql/connect/plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -1120,7 +1120,7 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:

@property
def observations(self) -> Dict[str, "Observation"]:
    """Merge observations from this node's child and its right branch.

    Using a dict literal deduplicates by observation name: when both
    branches carry the same key (e.g. a self-join of an observed
    DataFrame), the right-hand entry wins and only one mapping survives.
    """
    # Diff residue removed: the superseded `return dict(**...)` line from
    # the old side of the hunk duplicated this return and broke the method.
    return {**super().observations, **self.right.observations}

def print(self, indent: int = 0) -> str:
i = " " * indent
Expand Down Expand Up @@ -1213,7 +1213,7 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:

@property
def observations(self) -> Dict[str, "Observation"]:
    """Merge observations from this node's child and its right branch.

    The dict-literal merge deduplicates entries that share a name across
    both branches, so a self-joined observed plan reports each
    observation once.
    """
    # Diff residue removed: the old-side `return dict(**...)` line
    # duplicated this return statement.
    return {**super().observations, **self.right.observations}

def print(self, indent: int = 0) -> str:
assert self.left is not None
Expand Down Expand Up @@ -1288,7 +1288,7 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:

@property
def observations(self) -> Dict[str, "Observation"]:
    """Merge observations from this node's child and its right branch.

    Dict-literal unpacking collapses duplicate observation names coming
    from both sides of the plan into a single entry.
    """
    # Diff residue removed: the superseded `return dict(**...)` line from
    # the old side of the hunk duplicated this return.
    return {**super().observations, **self.right.observations}

def print(self, indent: int = 0) -> str:
i = " " * indent
Expand Down Expand Up @@ -1354,10 +1354,10 @@ def plan(self, session: "SparkConnectClient") -> proto.Relation:

@property
def observations(self) -> Dict[str, "Observation"]:
    """Merge observations from the child plan and, when present, `self.other`.

    `self.other` may be None (e.g. a set operation built without a second
    branch yet); in that case only the child's observations are returned.
    The dict-literal merge deduplicates entries sharing a name.
    """
    # Diff residue removed: the old-side `return dict(` / `)` lines were
    # interleaved with the new `{` / `}` form, breaking the expression.
    return {
        **super().observations,
        **(self.other.observations if self.other is not None else {}),
    }

def print(self, indent: int = 0) -> str:
assert self._child is not None
Expand Down Expand Up @@ -1664,7 +1664,7 @@ def observations(self) -> Dict[str, "Observation"]:
observations = {str(self._observation._name): self._observation}
else:
observations = {}
return dict(**super().observations, **observations)
return {**super().observations, **observations}


class NAFill(LogicalPlan):
Expand Down
95 changes: 94 additions & 1 deletion python/pyspark/sql/tests/connect/test_connect_plan.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,21 @@
)
from pyspark.errors import PySparkValueError

from unittest.mock import MagicMock

if should_test_connect:
import pyspark.sql.connect.proto as proto
from pyspark.sql.connect.column import Column
from pyspark.sql.connect.dataframe import DataFrame
from pyspark.sql.connect.plan import WriteOperation, Read
from pyspark.sql.connect.plan import (
WriteOperation,
Read,
Join,
SetOperation,
CollectMetrics,
LogicalPlan,
)
from pyspark.sql.connect.observation import Observation
from pyspark.sql.connect.readwriter import DataFrameReader
from pyspark.sql.connect.expressions import LiteralExpression
from pyspark.sql.connect.functions import col, lit, max, min, sum
Expand Down Expand Up @@ -1131,6 +1141,89 @@ def test_literal_to_any_conversion(self):
LiteralExpression._to_value(proto_lit, DoubleType)


if should_test_connect:

class _StubPlan(LogicalPlan):
    """Minimal LogicalPlan whose ``observations`` is a caller-supplied dict."""

    def __init__(self, observations=None):
        super().__init__(None)
        # Fall back to an empty mapping when no observations are given.
        self._obs = observations or {}

    @property
    def observations(self):
        # Hand back exactly the mapping captured at construction time.
        return self._obs

    def plan(self, session):
        # Stub plans are never serialized to proto.
        raise NotImplementedError

    def print(self, indent=0):
        # Nothing meaningful to render for a stub.
        return ""


@unittest.skipIf(not should_test_connect, connect_requirement_message)
class TestObservationMerging(unittest.TestCase):
    """Verify that observations are deduplicated when plan branches share the same key."""

    @staticmethod
    def _mock_observation(name):
        # A MagicMock standing in for an Observation carrying the given name.
        mock = MagicMock()
        mock._name = name
        return mock

    def test_join_with_duplicate_observation_names(self):
        shared_obs = self._mock_observation("shared")
        branch_observations = {"shared": shared_obs}

        # Bypass __init__ so only the attributes the property reads are set.
        join_plan = Join.__new__(Join)
        join_plan._child = _StubPlan(observations=branch_observations)
        join_plan.right = _StubPlan(observations=branch_observations)

        self.assertEqual(join_plan.observations, {"shared": shared_obs})

    def test_join_with_distinct_observations(self):
        left_obs = self._mock_observation("a")
        right_obs = self._mock_observation("b")

        join_plan = Join.__new__(Join)
        join_plan._child = _StubPlan(observations={"a": left_obs})
        join_plan.right = _StubPlan(observations={"b": right_obs})

        self.assertEqual(join_plan.observations, {"a": left_obs, "b": right_obs})

    def test_set_operation_with_duplicate_observation_names(self):
        shared_obs = self._mock_observation("shared")
        branch_observations = {"shared": shared_obs}

        union_plan = SetOperation.__new__(SetOperation)
        union_plan._child = _StubPlan(observations=branch_observations)
        union_plan.other = _StubPlan(observations=branch_observations)

        self.assertEqual(union_plan.observations, {"shared": shared_obs})

    def test_collect_metrics_with_duplicate_observation_name(self):
        observation = Observation("my_metric")

        metrics_plan = CollectMetrics.__new__(CollectMetrics)
        metrics_plan._child = _StubPlan(observations={"my_metric": observation})
        metrics_plan._observation = observation
        metrics_plan._exprs = []

        self.assertEqual(metrics_plan.observations, {"my_metric": observation})


if __name__ == "__main__":
from pyspark.testing import main

Expand Down
113 changes: 113 additions & 0 deletions python/pyspark/sql/tests/test_observation.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,16 @@ def test_observe(self):
messageParameters={},
)

new_observation = Observation("new")
Comment thread
mwojtyczka marked this conversation as resolved.
Outdated
with self.assertRaises(PySparkAssertionError) as pe:
Comment thread
mwojtyczka marked this conversation as resolved.
Outdated
df.observe(new_observation, 2 * F.count(F.lit(1)).alias("cnt"))
Comment thread
mwojtyczka marked this conversation as resolved.
Outdated

self.check_error(
exception=pe.exception,
errorClass="DUPLICATED_METRICS_NAME",
messageParameters={"metricName": "new"},
)

# observation requires name (if given) to be non empty string
with self.assertRaisesRegex(PySparkTypeError, "`name` should be str, got int"):
Observation(123)
Expand Down Expand Up @@ -263,6 +273,109 @@ def test_observation_errors_propagated_to_client(self):

self.assertIn("test error", str(cm.exception))

def test_observe_self_join(self):
    """SPARK-56322: self-joining an observed DataFrame must not duplicate
    the observation, and re-using an observation or metric name must raise."""
    # NOTE: GitHub review-thread residue ("Comment thread" / "marked as
    # resolved" / "Outdated") embedded in the pasted diff has been removed;
    # those lines were not code and broke the method body.
    obs = Observation("my_observation")
    df = (
        self.spark.range(100)
        .selectExpr("id", "CASE WHEN id < 10 THEN 'A' ELSE 'B' END AS group_key")
        .observe(obs, F.count(F.lit(1)).alias("row_count"))
    )

    df1 = df.where("id < 20")
    df2 = df.where("id % 2 == 0")

    joined = df1.alias("a").join(df2.alias("b"), on=["id"], how="inner")
    result = joined.collect()

    # The join should produce rows where id < 20 AND id is even
    expected_ids = sorted([i for i in range(20) if i % 2 == 0])
    actual_ids = sorted([row.id for row in result])
    self.assertEqual(actual_ids, expected_ids)

    # The observation should have been collected
    self.assertEqual(obs.get, {"row_count": 100})

    # Re-using a collected observation is an error
    with self.assertRaises(PySparkAssertionError) as pe:
        joined.observe(obs, F.count(F.lit(1)).alias("row_count")).collect()

    self.check_error(
        exception=pe.exception,
        errorClass="REUSE_OBSERVATION",
        messageParameters={},
    )

    # A fresh observation whose metric name collides must also fail
    obs2 = Observation("12345")
    with self.assertRaises(PySparkAssertionError) as pe:
        joined.observe(obs2, 2 * F.count(F.lit(1)).alias("row_count")).collect()

    self.check_error(
        exception=pe.exception,
        errorClass="DUPLICATED_METRICS_NAME",
        messageParameters={"metricName": "12345"},
    )

def test_observe_lateral_join(self):
    """SPARK-56322: lateral self-joining an observed DataFrame collects the
    observation once and rejects observation/metric-name reuse."""
    # NOTE: GitHub review-thread residue ("Comment thread" / "marked as
    # resolved" / "Outdated") embedded in the pasted diff has been removed;
    # those lines were not code and broke the method body.
    obs = Observation("lateral_join_obs")
    df = self.spark.range(50).observe(obs, F.count(F.lit(1)).alias("row_count"))

    joined = (
        df.alias("left")
        .lateralJoin(
            df.alias("right"), on=F.expr("right.id between left.id - 1 and left.id + 1")
        )
        .selectExpr("left.id as left_id", "right.id as right_id")
    )
    result = joined.collect()

    # Joins on row 0 should produce rows 0 and 1
    bounded_matches = sorted([r.right_id for r in result if r.left_id == 0])
    self.assertEqual(bounded_matches, [0, 1])

    # Joins on row 25 should produce rows 24, 25, and 26
    unbounded_matches = sorted([r.right_id for r in result if r.left_id == 25])
    self.assertEqual(unbounded_matches, [24, 25, 26])

    # The observation should have been collected
    self.assertEqual(obs.get, {"row_count": 50})

    # Re-using a collected observation is an error
    with self.assertRaises(PySparkAssertionError) as reused:
        joined.observe(obs, F.count(F.lit(1)).alias("row_count")).collect()

    self.check_error(
        exception=reused.exception,
        errorClass="REUSE_OBSERVATION",
        messageParameters={},
    )

    # A fresh observation whose metric name collides must also fail
    obs2 = Observation("12345")
    with self.assertRaises(PySparkAssertionError) as pe:
        joined.observe(obs2, F.count(2 * F.lit(1)).alias("row_count")).collect()

    self.check_error(
        exception=pe.exception,
        errorClass="DUPLICATED_METRICS_NAME",
        messageParameters={"metricName": "12345"},
    )

def test_observe_self_join_union(self):
    """SPARK-56322: union of two branches of the same observed DataFrame."""
    observation = Observation("union_obs")
    observed = self.spark.range(50).observe(observation, F.count(F.lit(1)).alias("cnt"))

    lower_half = observed.where("id < 25")
    upper_half = observed.where("id >= 25")

    collected = lower_half.union(upper_half).collect()

    # Both halves together cover the full range exactly once.
    self.assertEqual(sorted(row.id for row in collected), list(range(50)))
    # The shared observation is still collected a single time.
    self.assertEqual(observation.get, {"cnt": 50})


class DataFrameObservationTests(
DataFrameObservationTestsMixin,
Expand Down