diff --git a/changelog/14445.bugfix.rst b/changelog/14445.bugfix.rst new file mode 100644 index 00000000000..aaae0c615f5 --- /dev/null +++ b/changelog/14445.bugfix.rst @@ -0,0 +1 @@ +Fixed assertion rewriting evaluating walrus operator (``:=``) expressions multiple times, causing incorrect test results when the expression had side effects (e.g., incrementing a counter or calling a function). diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 99815b70cf1..3fa3217f6e0 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -3,7 +3,6 @@ from __future__ import annotations import ast -from collections import defaultdict from collections.abc import Callable from collections.abc import Iterable from collections.abc import Iterator @@ -58,10 +57,6 @@ from _pytest.assertion import AssertionState -class Sentinel: - pass - - assertstate_key = StashKey["AssertionState"]() # pytest caches rewritten pycs in pycache dirs @@ -69,9 +64,6 @@ class Sentinel: PYC_EXT = ".py" + ((__debug__ and "c") or "o") PYC_TAIL = "." + PYTEST_TAG + PYC_EXT -# Special marker that denotes we have just left a scope definition -_SCOPE_END_MARKER = Sentinel() - class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader): """PEP302/PEP451 import hook which rewrites asserts.""" @@ -652,14 +644,8 @@ class AssertionRewriter(ast.NodeVisitor): .push_format_context() and .pop_format_context() which allows to build another %-formatted string while already building one. - :scope: A tuple containing the current scope used for variables_overwrite. - - :variables_overwrite: A dict filled with references to variables - that change value within an assert. This happens when a variable is - reassigned with the walrus operator - - This state, except the variables_overwrite, is reset on every new assert - statement visited and used by the other visitors. + This state is reset on every new assert statement visited and used by + the other visitors. """ def __init__( @@ -675,10 +661,6 @@ def __init__( else: self.enable_assertion_pass_hook = False self.source = source - self.scope: tuple[ast.AST, ...] = () - self.variables_overwrite: defaultdict[tuple[ast.AST, ...], dict[str, str]] = ( - defaultdict(dict) - ) def run(self, mod: ast.Module) -> None: """Find all assert statements in *mod* and rewrite them.""" @@ -728,16 +710,9 @@ def run(self, mod: ast.Module) -> None: mod.body[pos:pos] = imports # Collect asserts. - self.scope = (mod,) - nodes: list[ast.AST | Sentinel] = [mod] + nodes: list[ast.AST] = [mod] while nodes: node = nodes.pop() - if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef): - self.scope = tuple((*self.scope, node)) - nodes.append(_SCOPE_END_MARKER) - if node == _SCOPE_END_MARKER: - self.scope = self.scope[:-1] - continue assert isinstance(node, ast.AST) for name, field in ast.iter_fields(node): if isinstance(field, list): @@ -964,15 +939,17 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]: return self.statements def visit_NamedExpr(self, name: ast.NamedExpr) -> tuple[ast.NamedExpr, str]: - # This method handles the 'walrus operator' repr of the target - # name if it's a local variable or _should_repr_global_name() - # thinks it's acceptable. + # Return the NamedExpr as-is so it evaluates in its natural position + # (preserving left-to-right evaluation order). For the explanation, + # reference the target variable (already assigned by the walrus) to + # avoid re-evaluating the expression. locs = ast.Call(self.builtin("locals"), [], []) target_id = name.target.id + target_name = ast.Name(target_id, ast.Load()) inlocs = ast.Compare(ast.Constant(target_id), [ast.In()], [locs]) - dorepr = self.helper("_should_repr_global_name", name) + dorepr = self.helper("_should_repr_global_name", target_name) test = ast.BoolOp(ast.Or(), [inlocs, dorepr]) - expr = ast.IfExp(test, self.display(name), ast.Constant(target_id)) + expr = ast.IfExp(test, self.display(target_name), ast.Constant(target_id)) return name, self.explanation_param(expr) def visit_Name(self, name: ast.Name) -> tuple[ast.Name, str]: @@ -998,20 +975,9 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]: for i, v in enumerate(boolop.values): if i: fail_inner: list[ast.stmt] = [] - # cond is set in a prior loop iteration below - self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa: F821 + # expl_cond is set in a prior loop iteration below + self.expl_stmts.append(ast.If(expl_cond, fail_inner, [])) # noqa: F821 self.expl_stmts = fail_inner - match v: - # Check if the left operand is an ast.NamedExpr and the value has already been visited - case ast.Compare( - left=ast.NamedExpr(target=ast.Name(id=target_id)) - ) if target_id in [ - e.id for e in boolop.values[:i] if hasattr(e, "id") - ]: - pytest_temp = self.variable() - self.variables_overwrite[self.scope][target_id] = v.left # type:ignore[assignment] - # mypy's false positive, we're checking that the 'target' attribute exists. - v.left.target.id = pytest_temp # type:ignore[attr-defined] self.push_format_context() res, expl = self.visit(v) body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) @@ -1019,11 +985,20 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]: call = ast.Call(app, [expl_format], []) self.expl_stmts.append(ast.Expr(call)) if i < levels: - cond: ast.expr = res + # Use res_var (already assigned above) rather than res directly, + # so that NamedExpr operands aren't evaluated a second time. + cond: ast.expr = ast.Name(res_var, ast.Load()) if is_or: cond = ast.UnaryOp(ast.Not(), cond) + # Capture the condition in a stable temp for the explanation + # path — res_var is overwritten by subsequent operands. + cond_var = self.variable() + body.append(ast.Assign([ast.Name(cond_var, ast.Store())], cond)) + expl_cond: ast.expr = ast.Name(cond_var, ast.Load()) # noqa: F841 inner: list[ast.stmt] = [] - self.statements.append(ast.If(cond, inner, [])) + self.statements.append( + ast.If(ast.Name(cond_var, ast.Load()), inner, []) + ) self.statements = body = inner self.statements = save self.expl_stmts = fail_save @@ -1053,19 +1028,10 @@ def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]: new_args = [] new_kwargs = [] for arg in call.args: - if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get( - self.scope, {} - ): - arg = self.variables_overwrite[self.scope][arg.id] # type:ignore[assignment] res, expl = self.visit(arg) arg_expls.append(expl) new_args.append(res) for keyword in call.keywords: - match keyword.value: - case ast.Name(id=id) if id in self.variables_overwrite.get( - self.scope, {} - ): - keyword.value = self.variables_overwrite[self.scope][id] # type:ignore[assignment] res, expl = self.visit(keyword.value) new_kwargs.append(ast.keyword(keyword.arg, res)) if keyword.arg: @@ -1100,17 +1066,13 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]: def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: self.push_format_context() - # We first check if we have overwritten a variable in the previous assert - match comp.left: - case ast.Name(id=name_id) if name_id in self.variables_overwrite.get( - self.scope, {} - ): - comp.left = self.variables_overwrite[self.scope][name_id] # type: ignore[assignment] - case ast.NamedExpr(target=ast.Name(id=target_id)): - self.variables_overwrite[self.scope][target_id] = comp.left # type: ignore[assignment] left_res, left_expl = self.visit(comp.left) if isinstance(comp.left, ast.Compare | ast.BoolOp): left_expl = f"({left_expl})" + # If the left operand is a NamedExpr, assign it to a temp so the + # walrus executes before any right-side expressions are hoisted. + if isinstance(left_res, ast.NamedExpr): + left_res = self.assign(left_res) res_variables = [self.variable() for i in range(len(comp.ops))] load_names: list[ast.expr] = [ast.Name(v, ast.Load()) for v in res_variables] store_names = [ast.Name(v, ast.Store()) for v in res_variables] @@ -1119,17 +1081,25 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: syms: list[ast.expr] = [] results = [left_res] for i, op, next_operand in it: + # If the next operand is a walrus that assigns to the same name as + # the current left_res, we must freeze left_res's value before the + # walrus modifies it. match (next_operand, left_res): case ( ast.NamedExpr(target=ast.Name(id=target_id)), ast.Name(id=name_id), ) if target_id == name_id: - next_operand.target.id = self.variable() - self.variables_overwrite[self.scope][name_id] = next_operand # type: ignore[assignment] + left_res = self.assign(left_res) + results[-1] = left_res next_res, next_expl = self.visit(next_operand) if isinstance(next_operand, ast.Compare | ast.BoolOp): next_expl = f"({next_expl})" + # Assign NamedExpr comparators to a temp so each walrus evaluates + # exactly once — critical for chained comparisons where the same + # node would otherwise be re-evaluated as left_res next iteration. + if isinstance(next_res, ast.NamedExpr): + next_res = self.assign(next_res) results.append(next_res) sym = BINOP_MAP[op.__class__] syms.append(ast.Constant(sym)) diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 2668001af65..11995321826 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1688,7 +1688,7 @@ def test_walrus_operator_change_boolean_value(): ) result = pytester.runpytest() assert result.ret == 1 - result.stdout.fnmatch_lines(["*assert not (True and False is False)"]) + result.stdout.fnmatch_lines(["*assert not (False and False is False)"]) def test_assertion_walrus_operator_boolean_none_fails( self, pytester: Pytester @@ -1702,7 +1702,7 @@ def test_walrus_operator_change_boolean_value(): ) result = pytester.runpytest() assert result.ret == 1 - result.stdout.fnmatch_lines(["*assert not (True and None is None)"]) + result.stdout.fnmatch_lines(["*assert not (None and None is None)"]) def test_assertion_walrus_operator_value_changes_cleared_after_each_test( self, pytester: Pytester @@ -1846,6 +1846,108 @@ def test_2(): assert result.ret == 0 +class TestIssue14445: + """Regression tests for #14445: walrus operator double evaluation.""" + + def test_walrus_no_double_eval_basic(self, pytester: Pytester) -> None: + """Walrus captures the value at assignment time, not re-evaluated later.""" + pytester.makepyfile( + """ + class Counter: + def __init__(self): + self.value = 0 + def increment(self): + self.value += 1 + + def test_walrus_in_assertion_basic(): + c = Counter() + assert (before := c.value) == 0 + c.increment() + assert before != (after := c.value) + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_walrus_no_double_eval_running_counter(self, pytester: Pytester) -> None: + """Walrus increments fire exactly once per assert statement.""" + pytester.makepyfile( + """ + def test_walrus_running_counter(): + count = 0 + items = [] + items.append("a") + assert (count := count + 1) == len(items) + items.append("b") + assert (count := count + 1) == len(items) + items.append("c") + assert (count := count + 1) == len(items) + assert count == 3 + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_walrus_no_double_eval_in_function_call(self, pytester: Pytester) -> None: + """Walrus in function call arguments not evaluated twice.""" + pytester.makepyfile( + """ + call_count = 0 + + def side_effect(): + global call_count + call_count += 1 + return call_count + + def test_walrus_side_effect(): + assert (val := side_effect()) == 1 + assert val == 1 + assert (val := side_effect()) == 2 + assert val == 2 + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_walrus_no_double_eval_in_boolop(self, pytester: Pytester) -> None: + """Bare walrus as a BoolOp operand must not be evaluated twice.""" + pytester.makepyfile( + """ + call_count = 0 + + def side_effect(): + global call_count + call_count += 1 + return call_count + + def test_walrus_boolop(): + assert (x := side_effect()) and x == 1 + assert call_count == 1 + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_walrus_no_double_eval_chained_compare(self, pytester: Pytester) -> None: + """Same walrus target in chained comparison must evaluate each once.""" + pytester.makepyfile( + """ + call_count = 0 + + def track(value): + global call_count + call_count += 1 + return value + + def test_walrus_chained(): + assert (x := track(1)) < (x := track(3)) < (x := track(5)) + assert call_count == 3 + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + @pytest.mark.skipif( sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems" )