Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ Chris Mahoney
Chris Lamb
Chris NeJame
Chris Rose
Chris Shucksmith (shuckc)
Chris Wheeler
Christian Boelsen
Christian Clauss
Expand Down
1 change: 1 addition & 0 deletions changelog/14445.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed duplicate evaluation of walrus assignments after assert statement rewrite.
86 changes: 51 additions & 35 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -869,6 +869,8 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]:
self.statements: list[ast.stmt] = []
self.variables: list[str] = []
self.variable_counter = itertools.count()
# Clear walrus overwrite tracking — only valid within a single assert.
self.variables_overwrite[self.scope] = {}

if self.enable_assertion_pass_hook:
self.format_variables: list[str] = []
Expand Down Expand Up @@ -964,15 +966,16 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]:
return self.statements

def visit_NamedExpr(self, name: ast.NamedExpr) -> tuple[ast.NamedExpr, str]:
# This method handles the 'walrus operator' repr of the target
# name if it's a local variable or _should_repr_global_name()
# thinks it's acceptable.
# Return the NamedExpr node itself to preserve in-place evaluation order.
# For the display (used in failure messages), reference the target variable
# rather than re-evaluating the NamedExpr.
locs = ast.Call(self.builtin("locals"), [], [])
target_id = name.target.id
target_name = ast.Name(target_id, ast.Load())
inlocs = ast.Compare(ast.Constant(target_id), [ast.In()], [locs])
dorepr = self.helper("_should_repr_global_name", name)
dorepr = self.helper("_should_repr_global_name", target_name)
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
expr = ast.IfExp(test, self.display(name), ast.Constant(target_id))
expr = ast.IfExp(test, self.display(target_name), ast.Constant(target_id))
return name, self.explanation_param(expr)

def visit_Name(self, name: ast.Name) -> tuple[ast.Name, str]:
Expand All @@ -998,20 +1001,9 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]:
for i, v in enumerate(boolop.values):
if i:
fail_inner: list[ast.stmt] = []
# cond is set in a prior loop iteration below
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa: F821
# expl_cond is set in a prior loop iteration below
self.expl_stmts.append(ast.If(expl_cond, fail_inner, [])) # noqa: F821
self.expl_stmts = fail_inner
match v:
# Check if the left operand is an ast.NamedExpr and the value has already been visited
case ast.Compare(
left=ast.NamedExpr(target=ast.Name(id=target_id))
) if target_id in [
e.id for e in boolop.values[:i] if hasattr(e, "id")
]:
pytest_temp = self.variable()
self.variables_overwrite[self.scope][target_id] = v.left # type:ignore[assignment]
# mypy's false positive, we're checking that the 'target' attribute exists.
v.left.target.id = pytest_temp # type:ignore[attr-defined]
self.push_format_context()
res, expl = self.visit(v)
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
Expand All @@ -1022,6 +1014,12 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]:
cond: ast.expr = res
if is_or:
cond = ast.UnaryOp(ast.Not(), cond)
# Save the condition value for the explanation path. A walrus
# in a later operand may modify the variable, but the saved
# value preserves the original truthiness for display purposes.
expl_cond_var = self.variable()
body.append(ast.Assign([ast.Name(expl_cond_var, ast.Store())], cond))
expl_cond: ast.expr = ast.Name(expl_cond_var, ast.Load())
inner: list[ast.stmt] = []
self.statements.append(ast.If(cond, inner, []))
self.statements = body = inner
Expand Down Expand Up @@ -1100,15 +1098,21 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]:

def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
self.push_format_context()
# We first check if we have overwritten a variable in the previous assert
match comp.left:
case ast.Name(id=name_id) if name_id in self.variables_overwrite.get(
self.scope, {}
):
comp.left = self.variables_overwrite[self.scope][name_id] # type: ignore[assignment]
case ast.NamedExpr(target=ast.Name(id=target_id)):
self.variables_overwrite[self.scope][target_id] = comp.left # type: ignore[assignment]
left_res, left_expl = self.visit(comp.left)
if isinstance(comp.left, ast.NamedExpr):
# Hoist the NamedExpr into a temp variable BEFORE visiting comparators,
# so that comparators referencing the walrus target see the assigned value.
# The assign evaluates (target := value), storing result in @py_assertN.
target_id = comp.left.target.id
left_res = self.assign(comp.left)
target_name = ast.Name(target_id, ast.Load())
locs = ast.Call(self.builtin("locals"), [], [])
inlocs = ast.Compare(ast.Constant(target_id), [ast.In()], [locs])
dorepr = self.helper("_should_repr_global_name", target_name)
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
expr = ast.IfExp(test, self.display(left_res), ast.Constant(target_id))
left_expl = self.explanation_param(expr)
else:
left_res, left_expl = self.visit(comp.left)
if isinstance(comp.left, ast.Compare | ast.BoolOp):
left_expl = f"({left_expl})"
res_variables = [self.variable() for i in range(len(comp.ops))]
Expand All @@ -1119,18 +1123,30 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
syms: list[ast.expr] = []
results = [left_res]
for i, op, next_operand in it:
match (next_operand, left_res):
case (
ast.NamedExpr(target=ast.Name(id=target_id)),
ast.Name(id=name_id),
) if target_id == name_id:
next_operand.target.id = self.variable()
self.variables_overwrite[self.scope][name_id] = next_operand # type: ignore[assignment]
# If the comparator is a NamedExpr that overwrites the left operand's
# variable, save the left value in a temp BEFORE the comparison so
# the failure message can display the pre-walrus value.
if (
isinstance(next_operand, ast.NamedExpr)
and isinstance(left_res, ast.Name)
and next_operand.target.id == left_res.id
):
saved_left = self.variable()
self.statements.append(
ast.Assign([ast.Name(saved_left, ast.Store())], left_res)
)
results[-1] = ast.Name(saved_left, ast.Load())

next_res, next_expl = self.visit(next_operand)
if isinstance(next_operand, ast.Compare | ast.BoolOp):
next_expl = f"({next_expl})"
results.append(next_res)
# For NamedExpr comparators, use the target variable in results
# (for the failure message) instead of the NamedExpr node itself,
# to avoid re-evaluating the walrus operator in the failure path.
if isinstance(next_operand, ast.NamedExpr):
results.append(ast.Name(next_operand.target.id, ast.Load()))
else:
results.append(next_res)
sym = BINOP_MAP[op.__class__]
syms.append(ast.Constant(sym))
expl = f"{left_expl} {sym} {next_expl}"
Expand Down
55 changes: 53 additions & 2 deletions testing/test_assertrewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -1688,7 +1688,7 @@ def test_walrus_operator_change_boolean_value():
)
result = pytester.runpytest()
assert result.ret == 1
result.stdout.fnmatch_lines(["*assert not (True and False is False)"])
result.stdout.fnmatch_lines(["*assert not (False and False is False)"])

def test_assertion_walrus_operator_boolean_none_fails(
self, pytester: Pytester
Expand All @@ -1702,7 +1702,7 @@ def test_walrus_operator_change_boolean_value():
)
result = pytester.runpytest()
assert result.ret == 1
result.stdout.fnmatch_lines(["*assert not (True and None is None)"])
result.stdout.fnmatch_lines(["*assert not (None and None is None)"])

def test_assertion_walrus_operator_value_changes_cleared_after_each_test(
self, pytester: Pytester
Expand Down Expand Up @@ -1846,6 +1846,57 @@ def test_2():
assert result.ret == 0


@pytest.mark.skipif(
sys.version_info < (3, 8), reason="walrus operator not available in py<38"
)
class TestIssue14445WalrusDoubleEval:
"""Test that walrus operator (:=) is not evaluated twice by assertion rewriting.

The rewriter must not re-evaluate NamedExpr nodes when building the
failure message, as that causes side effects to fire twice.
"""

def test_walrus_operator_not_double_evaluated(self, pytester: Pytester) -> None:
"""Walrus assigns wrong value when rewriter evaluates := twice."""
pytester.makepyfile(
"""
class Counter:
def __init__(self):
self.value = 0

def increment(self):
self.value += 1

def test_walrus_basic():
c = Counter()
assert (before := c.value) == 0
c.increment()
assert before != (after := c.value)
"""
)
result = pytester.runpytest()
assert result.ret == 0

def test_walrus_operator_cumulative_not_doubled(self, pytester: Pytester) -> None:
"""Cumulative walrus increments should not fire twice."""
pytester.makepyfile(
"""
def test_walrus_running_counter():
count = 0
items = []
items.append("a")
assert (count := count + 1) == len(items)
items.append("b")
assert (count := count + 1) == len(items)
items.append("c")
assert (count := count + 1) == len(items)
assert count == 3
"""
)
result = pytester.runpytest()
assert result.ret == 0


@pytest.mark.skipif(
sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems"
)
Expand Down
Loading