Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions changelog/14445.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Fixed assertion rewriting evaluating walrus operator (``:=``) expressions multiple times, causing incorrect test results when the expression had side effects (e.g., incrementing a counter or calling a function).
178 changes: 111 additions & 67 deletions src/_pytest/assertion/rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from __future__ import annotations

import ast
from collections import defaultdict
from collections.abc import Callable
from collections.abc import Iterable
from collections.abc import Iterator
Expand Down Expand Up @@ -58,20 +57,13 @@
from _pytest.assertion import AssertionState


class Sentinel:
pass


assertstate_key = StashKey["AssertionState"]()

# pytest caches rewritten pycs in pycache dirs
PYTEST_TAG = f"{sys.implementation.cache_tag}-pytest-{version}"
PYC_EXT = ".py" + ((__debug__ and "c") or "o")
PYC_TAIL = "." + PYTEST_TAG + PYC_EXT

# Special marker that denotes we have just left a scope definition
_SCOPE_END_MARKER = Sentinel()


class AssertionRewritingHook(importlib.abc.MetaPathFinder, importlib.abc.Loader):
"""PEP302/PEP451 import hook which rewrites asserts."""
Expand Down Expand Up @@ -652,14 +644,8 @@ class AssertionRewriter(ast.NodeVisitor):
.push_format_context() and .pop_format_context() which allows
to build another %-formatted string while already building one.

:scope: A tuple containing the current scope used for variables_overwrite.

:variables_overwrite: A dict filled with references to variables
that change value within an assert. This happens when a variable is
reassigned with the walrus operator

This state, except the variables_overwrite, is reset on every new assert
statement visited and used by the other visitors.
This state is reset on every new assert statement visited and used by
the other visitors.
"""

def __init__(
Expand All @@ -675,10 +661,6 @@ def __init__(
else:
self.enable_assertion_pass_hook = False
self.source = source
self.scope: tuple[ast.AST, ...] = ()
self.variables_overwrite: defaultdict[tuple[ast.AST, ...], dict[str, str]] = (
defaultdict(dict)
)

def run(self, mod: ast.Module) -> None:
"""Find all assert statements in *mod* and rewrite them."""
Expand Down Expand Up @@ -728,16 +710,9 @@ def run(self, mod: ast.Module) -> None:
mod.body[pos:pos] = imports

# Collect asserts.
self.scope = (mod,)
nodes: list[ast.AST | Sentinel] = [mod]
nodes: list[ast.AST] = [mod]
while nodes:
node = nodes.pop()
if isinstance(node, ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef):
self.scope = tuple((*self.scope, node))
nodes.append(_SCOPE_END_MARKER)
if node == _SCOPE_END_MARKER:
self.scope = self.scope[:-1]
continue
assert isinstance(node, ast.AST)
for name, field in ast.iter_fields(node):
if isinstance(field, list):
Expand Down Expand Up @@ -964,15 +939,17 @@ def visit_Assert(self, assert_: ast.Assert) -> list[ast.stmt]:
return self.statements

def visit_NamedExpr(self, name: ast.NamedExpr) -> tuple[ast.NamedExpr, str]:
# This method handles the 'walrus operator' repr of the target
# name if it's a local variable or _should_repr_global_name()
# thinks it's acceptable.
# Return the NamedExpr as-is so it evaluates in its natural position
# (preserving left-to-right evaluation order). For the explanation,
# reference the target variable (already assigned by the walrus) to
# avoid re-evaluating the expression.
locs = ast.Call(self.builtin("locals"), [], [])
target_id = name.target.id
target_name = ast.Name(target_id, ast.Load())
inlocs = ast.Compare(ast.Constant(target_id), [ast.In()], [locs])
dorepr = self.helper("_should_repr_global_name", name)
dorepr = self.helper("_should_repr_global_name", target_name)
test = ast.BoolOp(ast.Or(), [inlocs, dorepr])
expr = ast.IfExp(test, self.display(name), ast.Constant(target_id))
expr = ast.IfExp(test, self.display(target_name), ast.Constant(target_id))
return name, self.explanation_param(expr)

def visit_Name(self, name: ast.Name) -> tuple[ast.Name, str]:
Expand All @@ -998,32 +975,30 @@ def visit_BoolOp(self, boolop: ast.BoolOp) -> tuple[ast.Name, str]:
for i, v in enumerate(boolop.values):
if i:
fail_inner: list[ast.stmt] = []
# cond is set in a prior loop iteration below
self.expl_stmts.append(ast.If(cond, fail_inner, [])) # noqa: F821
# expl_cond is set in a prior loop iteration below
self.expl_stmts.append(ast.If(expl_cond, fail_inner, [])) # noqa: F821
self.expl_stmts = fail_inner
match v:
# Check if the left operand is an ast.NamedExpr and the value has already been visited
case ast.Compare(
left=ast.NamedExpr(target=ast.Name(id=target_id))
) if target_id in [
e.id for e in boolop.values[:i] if hasattr(e, "id")
]:
pytest_temp = self.variable()
self.variables_overwrite[self.scope][target_id] = v.left # type:ignore[assignment]
# mypy's false positive, we're checking that the 'target' attribute exists.
v.left.target.id = pytest_temp # type:ignore[attr-defined]
self.push_format_context()
res, expl = self.visit(v)
body.append(ast.Assign([ast.Name(res_var, ast.Store())], res))
expl_format = self.pop_format_context(ast.Constant(expl))
call = ast.Call(app, [expl_format], [])
self.expl_stmts.append(ast.Expr(call))
if i < levels:
cond: ast.expr = res
# Use res_var (already assigned above) rather than res directly,
# so that NamedExpr operands aren't evaluated a second time.
cond: ast.expr = ast.Name(res_var, ast.Load())
if is_or:
cond = ast.UnaryOp(ast.Not(), cond)
# Capture the condition in a stable temp for the explanation
# path — res_var is overwritten by subsequent operands.
cond_var = self.variable()
body.append(ast.Assign([ast.Name(cond_var, ast.Store())], cond))
expl_cond: ast.expr = ast.Name(cond_var, ast.Load()) # noqa: F841
inner: list[ast.stmt] = []
self.statements.append(ast.If(cond, inner, []))
self.statements.append(
ast.If(ast.Name(cond_var, ast.Load()), inner, [])
)
self.statements = body = inner
self.statements = save
self.expl_stmts = fail_save
Expand All @@ -1048,24 +1023,21 @@ def visit_BinOp(self, binop: ast.BinOp) -> tuple[ast.Name, str]:
return res, explanation

def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
# For method calls (obj.method()), produce a flat explanation like
# "where result = obj.method(args)" instead of nesting the attribute
# access as a separate "where method = obj.method" line.
if isinstance(call.func, ast.Attribute) and isinstance(call.func.ctx, ast.Load):
return self._visit_method_call(call)

new_func, func_expl = self.visit(call.func)
arg_expls = []
new_args = []
new_kwargs = []
for arg in call.args:
if isinstance(arg, ast.Name) and arg.id in self.variables_overwrite.get(
self.scope, {}
):
arg = self.variables_overwrite[self.scope][arg.id] # type:ignore[assignment]
res, expl = self.visit(arg)
arg_expls.append(expl)
new_args.append(res)
for keyword in call.keywords:
match keyword.value:
case ast.Name(id=id) if id in self.variables_overwrite.get(
self.scope, {}
):
keyword.value = self.variables_overwrite[self.scope][id] # type:ignore[assignment]
res, expl = self.visit(keyword.value)
new_kwargs.append(ast.keyword(keyword.arg, res))
if keyword.arg:
Expand All @@ -1080,12 +1052,80 @@ def visit_Call(self, call: ast.Call) -> tuple[ast.Name, str]:
outer_expl = f"{res_expl}\n{{{res_expl} = {expl}\n}}"
return res, outer_expl

def _visit_method_call(self, call: ast.Call) -> tuple[ast.Name, str]:
r"""Handle obj.method(...) calls with a flat explanation format.

Produces: "result\n{result = obj_repr.method(args)\n}"
instead of nesting the bound-method intermediate.
"""
attr = call.func
assert isinstance(attr, ast.Attribute)

# Visit the object (receiver) for introspection.
obj_res, obj_expl = self.visit(attr.value)

# Visit arguments.
arg_expls = []
new_args = []
new_kwargs = []
for arg in call.args:
res, expl = self.visit(arg)
arg_expls.append(expl)
new_args.append(res)
for keyword in call.keywords:
res, expl = self.visit(keyword.value)
new_kwargs.append(ast.keyword(keyword.arg, res))
if keyword.arg:
arg_expls.append(keyword.arg + "=" + expl)
else:
arg_expls.append("**" + expl)

# Build the call using the rewritten object's attribute.
new_func = ast.Attribute(obj_res, attr.attr, ast.Load())
new_call = ast.copy_location(ast.Call(new_func, new_args, new_kwargs), call)
res = self.assign(new_call)
res_expl = self.explanation_param(self.display(res))
args_str = ", ".join(arg_expls)
expl = f"{res_expl}\n{{{res_expl} = {obj_expl}.{attr.attr}({args_str})\n}}"
return res, expl

def visit_Starred(self, starred: ast.Starred) -> tuple[ast.Starred, str]:
# A Starred node can appear in a function call.
res, expl = self.visit(starred.value)
new_starred = ast.Starred(res, starred.ctx)
return new_starred, "*" + expl

def visit_IfExp(self, ifexp: ast.IfExp) -> tuple[ast.Name, str]:
# Introspect the condition but keep branches as-is to preserve
# short-circuit semantics (only the selected branch is evaluated).
cond_res, cond_expl = self.visit(ifexp.test)
# Reconstruct the IfExp with the rewritten condition but original
# branches to avoid evaluating both sides.
res = self.assign(
ast.copy_location(ast.IfExp(cond_res, ifexp.body, ifexp.orelse), ifexp)
)
res_expl = self.explanation_param(self.display(res))
pat = "%s\n{%s = (... if %s else ...)\n}"
expl = pat % (res_expl, res_expl, cond_expl)
return res, expl

def visit_Subscript(self, subscript: ast.Subscript) -> tuple[ast.Name, str]:
if not isinstance(subscript.ctx, ast.Load):
return self.generic_visit(subscript)
# For Slice objects (a[1:3]), fall back to generic — decomposing
# start/stop/step is rarely useful in assertion messages.
if isinstance(subscript.slice, ast.Slice):
return self.generic_visit(subscript)
value, value_expl = self.visit(subscript.value)
slice_res, slice_expl = self.visit(subscript.slice)
res = self.assign(
ast.copy_location(ast.Subscript(value, slice_res, ast.Load()), subscript)
)
res_expl = self.explanation_param(self.display(res))
pat = "%s\n{%s = %s[%s]\n}"
expl = pat % (res_expl, res_expl, value_expl, slice_expl)
return res, expl

def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]:
if not isinstance(attr.ctx, ast.Load):
return self.generic_visit(attr)
Expand All @@ -1100,17 +1140,13 @@ def visit_Attribute(self, attr: ast.Attribute) -> tuple[ast.Name, str]:

def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
self.push_format_context()
# We first check if we have overwritten a variable in the previous assert
match comp.left:
case ast.Name(id=name_id) if name_id in self.variables_overwrite.get(
self.scope, {}
):
comp.left = self.variables_overwrite[self.scope][name_id] # type: ignore[assignment]
case ast.NamedExpr(target=ast.Name(id=target_id)):
self.variables_overwrite[self.scope][target_id] = comp.left # type: ignore[assignment]
left_res, left_expl = self.visit(comp.left)
if isinstance(comp.left, ast.Compare | ast.BoolOp):
left_expl = f"({left_expl})"
# If the left operand is a NamedExpr, assign it to a temp so the
# walrus executes before any right-side expressions are hoisted.
if isinstance(left_res, ast.NamedExpr):
left_res = self.assign(left_res)
res_variables = [self.variable() for i in range(len(comp.ops))]
load_names: list[ast.expr] = [ast.Name(v, ast.Load()) for v in res_variables]
store_names = [ast.Name(v, ast.Store()) for v in res_variables]
Expand All @@ -1119,17 +1155,25 @@ def visit_Compare(self, comp: ast.Compare) -> tuple[ast.expr, str]:
syms: list[ast.expr] = []
results = [left_res]
for i, op, next_operand in it:
# If the next operand is a walrus that assigns to the same name as
# the current left_res, we must freeze left_res's value before the
# walrus modifies it.
match (next_operand, left_res):
case (
ast.NamedExpr(target=ast.Name(id=target_id)),
ast.Name(id=name_id),
) if target_id == name_id:
next_operand.target.id = self.variable()
self.variables_overwrite[self.scope][name_id] = next_operand # type: ignore[assignment]
left_res = self.assign(left_res)
results[-1] = left_res

next_res, next_expl = self.visit(next_operand)
if isinstance(next_operand, ast.Compare | ast.BoolOp):
next_expl = f"({next_expl})"
# Assign NamedExpr comparators to a temp so each walrus evaluates
# exactly once — critical for chained comparisons where the same
# node would otherwise be re-evaluated as left_res next iteration.
if isinstance(next_res, ast.NamedExpr):
next_res = self.assign(next_res)
results.append(next_res)
sym = BINOP_MAP[op.__class__]
syms.append(ast.Constant(sym))
Expand Down
3 changes: 1 addition & 2 deletions testing/python/raises_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -1237,8 +1237,7 @@ def test_assert_matches() -> None:
match=wrap_escape(
"`ValueError()` is not an instance of `TypeError`\n"
"assert False\n"
" + where False = matches(ValueError())\n"
" + where matches = RaisesExc(TypeError).matches"
" + where False = RaisesExc(TypeError).matches(ValueError())"
),
):
# you'd need to do this arcane incantation
Expand Down
Loading
Loading