
[SPARK-56322][CONNECT][PYTHON] Fix TypeError when self-joining observed DataFrames#55140

Open
mwojtyczka wants to merge 8 commits into apache:master from
mwojtyczka:fix-observation-self-join

Conversation


@mwojtyczka mwojtyczka commented Apr 1, 2026

What changes were proposed in this pull request?

Fixes https://issues.apache.org/jira/browse/SPARK-56322

Replace dict(**a, **b) with {**a, **b} dict literal syntax when merging observations across plan branches in Join, AsOfJoin, LateralJoin, SetOperation, and CollectMetrics.

Why are the changes needed?

When a DataFrame with .observe() is filtered into two subsets and then self-joined, both branches of the join carry the same Observation instance under the same name. The observations property merges left and right observations using dict(**left, **right), which raises TypeError when both dicts contain the same key:

TypeError: dict() got multiple values for keyword argument 'my_observation'

This is a Python semantics issue — dict(**a, **b) treats each key as a keyword argument, and Python does not allow duplicate keyword arguments. The dict literal {**a, **b} does not have this restriction and silently lets the last value win, which is correct here since both values are the same Observation instance originating from the same .observe() call.

Why "last value wins" is safe here: When a DataFrame is observed and then branched (filtered, aliased), both branches inherit a reference to the same Observation instance. The duplicate keys in the merge always map to the identical Python object — there is no scenario where two different Observation instances share the same name within a single plan tree. Therefore, deduplication does not lose any data.
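The two merge behaviors described above can be seen in isolation with a minimal stand-in (the `Observation` class below is a sketch, not pyspark's):

```python
# Standalone illustration of the merge semantics described above; this
# Observation class is a minimal stand-in, not pyspark.sql.Observation.
class Observation:
    def __init__(self, name):
        self.name = name

obs = Observation("my_observation")

# Both join branches carry the same instance under the same key.
left = {"my_observation": obs}
right = {"my_observation": obs}

# dict(**a, **b) passes each key as a keyword argument, and Python
# rejects duplicate keyword arguments at call time.
try:
    dict(**left, **right)
    raised = False
except TypeError:
    raised = True  # dict() got multiple values for keyword argument ...

# The dict literal tolerates the duplicate; the last value wins, and
# nothing is lost because both values are the identical object.
merged = {**left, **right}
```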

This pattern affects any workflow that:

  • Observes a DataFrame (e.g., for monitoring row counts or data quality metrics)
  • Filters or transforms it into multiple subsets
  • Joins the subsets back together

This is common in data quality pipelines (split into valid/invalid rows, then rejoin) and ETL workflows that branch and merge.

How to reproduce

from pyspark.sql import Observation
from pyspark.sql.functions import count, lit

obs = Observation("my_observation")
df = (
    spark.range(100)
    .selectExpr("id", "case when id < 10 then 'A' else 'B' end as group_key")
    .observe(obs, count(lit(1)).alias("row_count"))
)

# Filter into two subsets — both carry the same observation
df1 = df.where("id < 20")
df2 = df.where("id % 2 == 0")

# Self-join triggers the bug
joined = df1.alias("a").join(df2.alias("b"), on=["id"], how="inner")
joined.collect()
# TypeError: dict() got multiple values for keyword argument 'my_observation'
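For context, the shape of the fix can be sketched with hypothetical plan-node classes (the real `Join`, `SetOperation`, and `CollectMetrics` classes live in `python/pyspark/sql/connect/plan.py`; the names below are illustrative only):

```python
# Hypothetical, simplified stand-ins for Spark Connect plan nodes; the
# real classes live in python/pyspark/sql/connect/plan.py.
class LeafNode:
    def __init__(self, observations):
        self.observations = observations

class BinaryNode:
    def __init__(self, left, right):
        self.left = left
        self.right = right

    @property
    def observations(self):
        # Before the fix: dict(**self.left.observations, **self.right.observations)
        # raised TypeError whenever both sides shared an observation name.
        # The dict literal tolerates the duplicate key; the last value wins,
        # which is safe because both sides hold the same Observation instance.
        return {**self.left.observations, **self.right.observations}

shared = object()  # stands in for the shared Observation instance
joined = BinaryNode(
    LeafNode({"my_observation": shared}),
    LeafNode({"my_observation": shared}),
)
merged = joined.observations  # no TypeError; a single entry survives
```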

Does this PR introduce any user-facing change?

Yes — self-joining an observed DataFrame no longer raises TypeError. No behavior change for joins where observations don't overlap (the common case).

How was this patch tested?

Added unit tests in test_connect_plan.py covering:

  • Join with duplicate observation names (the reported scenario)
  • Join with distinct observations (regression check)
  • SetOperation with duplicate observation names
  • CollectMetrics with parent sharing the same observation name

All tests fail on the unpatched code with the exact TypeError and pass with the fix.

Was this patch authored or co-authored using generative AI tooling?

Yes, Claude Code was used for testing.

@mwojtyczka mwojtyczka force-pushed the fix-observation-self-join branch from 0cf0cc8 to 530cc24 Compare April 1, 2026 14:52
@mwojtyczka mwojtyczka changed the title [SPARK-XXXXX][CONNECT][PYTHON] Fix TypeError when self-joining observed DataFrames [SPARK-56322][CONNECT][PYTHON] Fix TypeError when self-joining observed DataFrames Apr 1, 2026
…ed DataFrames

When a DataFrame with .observe() is filtered and then self-joined, both
branches carry the same Observation under the same name. The observations
property merged them using dict(**left, **right), which raises TypeError
on duplicate keyword arguments.

Replace dict(**a, **b) with {**a, **b} dict literal syntax in Join,
AsOfJoin, LateralJoin, SetOperation, and CollectMetrics. Dict literals
handle duplicate keys by letting the last value win, which is correct
here since both values are the same Observation instance.

Closes apache#55140

ueshin commented Apr 1, 2026

cc @hvanhovell What's the expected behavior if the same Observation appears on both sides of a join?

  • 100 because df is observed
  • 200 because df appears on both sides
  • invalid plan because each Observation should appear only once in the whole plan, e.g.
>>> df.observe(obs, count(lit(1).alias("row_count")))
Traceback (most recent call last):
...
pyspark.errors.exceptions.base.PySparkAssertionError: [REUSE_OBSERVATION] An Observation can be used with a DataFrame only once.

Currently classic returns 100, but I'm not sure whether it's correct.

UPDATE: nvm, the other case's error message suggests it's OK to appear multiple times in a self-join:

>>> obs2 = Observation("12345")
>>> df.observe(obs2, count(lit(1).alias("row_count")))
Traceback (most recent call last):
...
pyspark.errors.exceptions.captured.AnalysisException: [DUPLICATED_METRICS_NAME] The metric name is not unique: 12345. The same name cannot be used for metrics with different results.
However multiple instances of metrics with with same result and name are allowed (e.g. self-joins). SQLSTATE: 42710;
CollectMetrics 12345, [count(1) AS count(1 AS row_count)#7L], 2
+- CollectMetrics 12345, [count(1) AS count(1 AS row_count)#4L], 1
   +- Project [id#0L, CASE WHEN (id#0L < cast(10 as bigint)) THEN A ELSE B END AS group_key#1]
      +- Range (0, 100, step=1, splits=Some(16))



ueshin commented Apr 1, 2026

@mwojtyczka btw, could you add tests to python/pyspark/sql/tests/test_observation.py to see if the behavior is the same between classic and connect?


@ueshin ueshin left a comment


LGTM, pending a test in test_observation.py.

@ueshin ueshin requested a review from hvanhovell April 1, 2026 23:58

ueshin commented Apr 2, 2026

@mwojtyczka Could you also enable GitHub Actions to run the CI?
Please follow the instructions at https://github.com/apache/spark/pull/55140/checks?check_run_id=69566409408. Thanks.


mwojtyczka commented Apr 2, 2026

@mwojtyczka btw, could you add tests to python/pyspark/sql/tests/test_observation.py to see if the behavior is the same between classic and connect?

thank you for the review!

  • Enabled workflows
  • Added tests for classic. They pass, so it's a Connect-only issue.


ghanse commented Apr 2, 2026

Added a test to cover LateralJoin nodes.


ueshin commented Apr 2, 2026

@mwojtyczka @ghanse Could you also add some more tests from my previous comment:

For:

obs = Observation("my_observation")
df = (
    spark.range(100)
    .selectExpr("id", "case when id < 10 then 'A' else 'B' end as group_key")
    .observe(obs, count(lit(1)).alias("row_count"))
)
  1. use obs twice:
df.observe(obs, count(lit(1).alias("row_count"))).collect()
  2. use another Observation, but with the same name:
obs2 = Observation("12345")
df.observe(obs2, count(lit(1).alias("row_count"))).collect()

Both should raise the exceptions shown above.

messageParameters={},
)

new_observation = Observation("new")

The name should be metric if checking DUPLICATED_METRICS_NAME.

)

new_observation = Observation("new")
with self.assertRaises(PySparkAssertionError) as pe:

The error should be AnalysisException for DUPLICATED_METRICS_NAME.


new_observation = Observation("new")
with self.assertRaises(PySparkAssertionError) as pe:
df.observe(new_observation, 2 * F.count(F.lit(1)).alias("cnt"))

Suggested change
df.observe(new_observation, 2 * F.count(F.lit(1)).alias("cnt"))
observed.observe(new_observation, 2 * F.count(F.lit(1)).alias("cnt")).collect()

messageParameters={},
)

obs2 = Observation("12345")

Suggested change
obs2 = Observation("12345")
obs2 = Observation("my_observation")

messageParameters={},
)

obs2 = Observation("12345")

Suggested change
obs2 = Observation("12345")
obs2 = Observation("lateral_join_obs")



3 participants