diff --git a/AUTHORS b/AUTHORS index d6d2737a4bf..41145722a55 100644 --- a/AUTHORS +++ b/AUTHORS @@ -92,6 +92,7 @@ Chris Mahoney Chris Lamb Chris NeJame Chris Rose +Chris Shucksmith (shuckc) Chris Wheeler Christian Boelsen Christian Clauss diff --git a/changelog/14445.bugfix.rst b/changelog/14445.bugfix.rst new file mode 100644 index 00000000000..ed6fc9455f4 --- /dev/null +++ b/changelog/14445.bugfix.rst @@ -0,0 +1 @@ +Fixed duplicate evaluation of walrus assignments after assert statement rewrite. diff --git a/src/_pytest/assertion/rewrite.py b/src/_pytest/assertion/rewrite.py index 99815b70cf1..db433fc2391 100644 --- a/src/_pytest/assertion/rewrite.py +++ b/src/_pytest/assertion/rewrite.py @@ -869,6 +869,8 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]: self.statements: list[ast.stmt] = [] self.variables: list[str] = [] self.variable_counter = itertools.count() + # Clear walrus overwrite tracking — only valid within a single assert. + self.variables_overwrite[self.scope] = {} if self.enable_assertion_pass_hook: self.format_variables: list[str] = [] @@ -964,15 +966,16 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]: return self.statements def visit_NamedExpr(self, name: ast.NamedExpr) -> tuple[ast.NamedExpr, str]: - # This method handles the 'walrus operator' repr of the target - # name if it's a local variable or _should_repr_global_name() - # thinks it's acceptable. + # Return the NamedExpr node itself to preserve in-place evaluation order. + # For the display (used in failure messages), reference the target variable + # rather than re-evaluating the NamedExpr. locs = ast.Call(self.builtin("locals"), [], []) target_id = name.target.id + target_name = ast.Name(target_id, ast.Load()) inlocs = ast.Compare(ast.Constant(target_id), [ast.In()], [locs]) - dorepr = self.helper("_should_repr_global_name", name) + dorepr = self.helper("_should_repr_global_name", target_name) test = ast.BoolOp(ast.Or(), [inlocs, dorepr]) - expr = ast.IfExp(test, self.display(name), ast.Constant(target_id)) + expr = ast.IfExp(test, self.display(target_name), ast.Constant(target_id)) return name, self.explanation_param(expr) def visit_Name(self, name: ast.Name) -> tuple[ast.Name, str]: @@ -998,20 +1001,9 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]: for i, v in enumerate(boolop.values): if i: fail_inner: list[ast.stmt] = [] - # cond is set in a prior loop iteration below - self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa: F821 + # expl_cond is set in a prior loop iteration below + self.expl_stmts.append(ast.If(expl_cond, fail_inner, [])) # noqa: F821 self.expl_stmts = fail_inner - match v: - # Check if the left operand is an ast.NamedExpr and the value has already been visited - case ast.Compare( - left=ast.NamedExpr(target=ast.Name(id=target_id)) - ) if target_id in [ - e.id for e in boolop.values[:i] if hasattr(e, "id") - ]: - pytest_temp = self.variable() - self.variables_overwrite[self.scope][target_id] = v.left # type:ignore[assignment] - # mypy's false positive, we're checking that the 'target' attribute exists. - v.left.target.id = pytest_temp # type:ignore[attr-defined] self.push_format_context() res, expl = self.visit(v) body.append(ast.Assign([ast.Name(res_var, ast.Store())], res)) @@ -1022,6 +1014,12 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]: cond: ast.expr = res if is_or: cond = ast.UnaryOp(ast.Not(), cond) + # Save the condition value for the explanation path. A walrus + # in a later operand may modify the variable, but the saved + # value preserves the original truthiness for display purposes. + expl_cond_var = self.variable() + body.append(ast.Assign([ast.Name(expl_cond_var, ast.Store())], cond)) + expl_cond: ast.expr = ast.Name(expl_cond_var, ast.Load()) inner: list[ast.stmt] = [] self.statements.append(ast.If(cond, inner, [])) self.statements = body = inner @@ -1100,15 +1098,21 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]: def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: self.push_format_context() - # We first check if we have overwritten a variable in the previous assert - match comp.left: - case ast.Name(id=name_id) if name_id in self.variables_overwrite.get( - self.scope, {} - ): - comp.left = self.variables_overwrite[self.scope][name_id] # type: ignore[assignment] - case ast.NamedExpr(target=ast.Name(id=target_id)): - self.variables_overwrite[self.scope][target_id] = comp.left # type: ignore[assignment] - left_res, left_expl = self.visit(comp.left) + if isinstance(comp.left, ast.NamedExpr): + # Hoist the NamedExpr into a temp variable BEFORE visiting comparators, + # so that comparators referencing the walrus target see the assigned value. + # The assign evaluates (target := value), storing result in @py_assertN. + target_id = comp.left.target.id + left_res = self.assign(comp.left) + target_name = ast.Name(target_id, ast.Load()) + locs = ast.Call(self.builtin("locals"), [], []) + inlocs = ast.Compare(ast.Constant(target_id), [ast.In()], [locs]) + dorepr = self.helper("_should_repr_global_name", target_name) + test = ast.BoolOp(ast.Or(), [inlocs, dorepr]) + expr = ast.IfExp(test, self.display(left_res), ast.Constant(target_id)) + left_expl = self.explanation_param(expr) + else: + left_res, left_expl = self.visit(comp.left) if isinstance(comp.left, ast.Compare | ast.BoolOp): left_expl = f"({left_expl})" res_variables = [self.variable() for i in range(len(comp.ops))] @@ -1119,18 +1123,30 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]: syms: list[ast.expr] = [] results = [left_res] for i, op, next_operand in it: - match (next_operand, left_res): - case ( - ast.NamedExpr(target=ast.Name(id=target_id)), - ast.Name(id=name_id), - ) if target_id == name_id: - next_operand.target.id = self.variable() - self.variables_overwrite[self.scope][name_id] = next_operand # type: ignore[assignment] + # If the comparator is a NamedExpr that overwrites the left operand's + # variable, save the left value in a temp BEFORE the comparison so + # the failure message can display the pre-walrus value. + if ( + isinstance(next_operand, ast.NamedExpr) + and isinstance(left_res, ast.Name) + and next_operand.target.id == left_res.id + ): + saved_left = self.variable() + self.statements.append( + ast.Assign([ast.Name(saved_left, ast.Store())], left_res) + ) + results[-1] = ast.Name(saved_left, ast.Load()) next_res, next_expl = self.visit(next_operand) if isinstance(next_operand, ast.Compare | ast.BoolOp): next_expl = f"({next_expl})" - results.append(next_res) + # For NamedExpr comparators, use the target variable in results + # (for the failure message) instead of the NamedExpr node itself, + # to avoid re-evaluating the walrus operator in the failure path. + if isinstance(next_operand, ast.NamedExpr): + results.append(ast.Name(next_operand.target.id, ast.Load())) + else: + results.append(next_res) sym = BINOP_MAP[op.__class__] syms.append(ast.Constant(sym)) expl = f"{left_expl} {sym} {next_expl}" diff --git a/testing/test_assertrewrite.py b/testing/test_assertrewrite.py index 2668001af65..1ad3ef2cfc0 100644 --- a/testing/test_assertrewrite.py +++ b/testing/test_assertrewrite.py @@ -1688,7 +1688,7 @@ def test_walrus_operator_change_boolean_value(): ) result = pytester.runpytest() assert result.ret == 1 - result.stdout.fnmatch_lines(["*assert not (True and False is False)"]) + result.stdout.fnmatch_lines(["*assert not (False and False is False)"]) def test_assertion_walrus_operator_boolean_none_fails( self, pytester: Pytester @@ -1702,7 +1702,7 @@ def test_walrus_operator_change_boolean_value(): ) result = pytester.runpytest() assert result.ret == 1 - result.stdout.fnmatch_lines(["*assert not (True and None is None)"]) + result.stdout.fnmatch_lines(["*assert not (None and None is None)"]) def test_assertion_walrus_operator_value_changes_cleared_after_each_test( self, pytester: Pytester @@ -1846,6 +1846,57 @@ def test_2(): assert result.ret == 0 +@pytest.mark.skipif( + sys.version_info < (3, 8), reason="walrus operator not available in py<38" +) +class TestIssue14445WalrusDoubleEval: + """Test that walrus operator (:=) is not evaluated twice by assertion rewriting. + + The rewriter must not re-evaluate NamedExpr nodes when building the + failure message, as that causes side effects to fire twice. + """ + + def test_walrus_operator_not_double_evaluated(self, pytester: Pytester) -> None: + """Walrus assigns wrong value when rewriter evaluates := twice.""" + pytester.makepyfile( + """ + class Counter: + def __init__(self): + self.value = 0 + + def increment(self): + self.value += 1 + + def test_walrus_basic(): + c = Counter() + assert (before := c.value) == 0 + c.increment() + assert before != (after := c.value) + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + def test_walrus_operator_cumulative_not_doubled(self, pytester: Pytester) -> None: + """Cumulative walrus increments should not fire twice.""" + pytester.makepyfile( + """ + def test_walrus_running_counter(): + count = 0 + items = [] + items.append("a") + assert (count := count + 1) == len(items) + items.append("b") + assert (count := count + 1) == len(items) + items.append("c") + assert (count := count + 1) == len(items) + assert count == 3 + """ + ) + result = pytester.runpytest() + assert result.ret == 0 + + @pytest.mark.skipif( sys.maxsize <= (2**31 - 1), reason="Causes OverflowError on 32bit systems" )