Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
138 changes: 131 additions & 7 deletions scripts/ci/prek/check_trigger_serialize_init.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,26 @@
through in-file base classes), flags any ``__init__`` parameter missing from the ``serialize()``
return dict.

Classes whose ``serialize()`` is built dynamically (``**spread`` of a non-``super()`` value,
``.update()``, returning a variable, ...) or that inherit ``__init__``/``serialize()`` from a base
class defined in another file cannot be resolved statically and are skipped -- the check never
guesses, so it produces no false positives.
Two ``serialize()`` shapes are resolved statically:

1. **Direct dict literal**: ``return "<classpath>", {"key": self.key, ...}``.
2. **Dict via local variable**: a single literal-dict initialization (``data = {...}`` or
``data: dict = {...}``), optionally followed by ``data["k"] = ...`` subscript assignments and
``data.update({...})`` calls with literal dict arguments, then ``return "<classpath>", data``.
All key-adding paths (including conditional branches) are unioned -- a key that *could* appear
in the output is treated as preserved on the round-trip.

Classes whose ``serialize()`` is built in any other dynamic way (``**spread`` of a non-``super()``
value, reassignment of the return variable, ``.update()`` with a non-literal argument, ...) or
that inherit ``__init__``/``serialize()`` from a base class defined in another file cannot be
resolved statically and are skipped -- the check never guesses, so it produces no false positives.
"""

from __future__ import annotations

import ast
import sys
from collections.abc import Iterator
from pathlib import Path

from common_prek_utils import AIRFLOW_PROVIDERS_ROOT_PATH, console
Expand Down Expand Up @@ -97,6 +107,21 @@ def _get_method(cls: ast.ClassDef, name: str) -> ast.FunctionDef | None:
return None


def _walk_skip_nested(node: ast.AST) -> Iterator[ast.AST]:
"""
Yield descendants of *node* without descending into nested function/class/lambda scopes.

Equivalent to ``ast.walk`` for typical control-flow nodes (``if``/``for``/``with`` bodies)
but stops at boundaries that introduce a new variable scope, so ``data["x"] = ...`` inside a
nested helper is not attributed to a same-named variable in the outer function.
"""
for child in ast.iter_child_nodes(node):
if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Lambda)):
continue
yield child
yield from _walk_skip_nested(child)


class ModuleAnalyzer:
"""Resolves trigger __init__/serialize() pairs within a single module, following in-file bases."""

Expand Down Expand Up @@ -157,16 +182,25 @@ def _get_serialize_keys(self, cls: ast.ClassDef, _seen: set[str] | None = None)
return None
serialize, defining_cls = resolved

returns = [n for n in ast.walk(serialize) if isinstance(n, ast.Return)]
# Only count return statements in ``serialize()``'s own scope -- a return inside a
# nested helper function (or comprehension) belongs to that helper, not to serialize.
returns = [n for n in _walk_skip_nested(serialize) if isinstance(n, ast.Return)]
if len(returns) != 1 or returns[0].value is None:
return None
ret = returns[0].value
if not isinstance(ret, (ast.Tuple, ast.List)) or len(ret.elts) != 2:
return None
payload = ret.elts[1]
if not isinstance(payload, ast.Dict):
return None
if isinstance(payload, ast.Dict):
return self._extract_keys_from_dict_literal(payload, defining_cls, _seen)
if isinstance(payload, ast.Name):
return self._extract_keys_from_local_var(payload.id, serialize, defining_cls, _seen)
return None

def _extract_keys_from_dict_literal(
self, payload: ast.Dict, defining_cls: ast.ClassDef, _seen: set[str]
) -> set[str] | None:
"""Resolve keys from a ``{...}`` literal, accepting ``**super().serialize()[1]`` spreads."""
keys: set[str] = set()
for key, value in zip(payload.keys, payload.values):
if key is None:
Expand All @@ -181,6 +215,96 @@ def _get_serialize_keys(self, cls: ast.ClassDef, _seen: set[str] | None = None)
keys.add(key.value)
return keys

def _extract_keys_from_local_var(
self,
var_name: str,
serialize: ast.FunctionDef,
defining_cls: ast.ClassDef,
_seen: set[str],
) -> set[str] | None:
"""
Resolve keys for the dict-via-variable pattern:

.. code-block:: python

def serialize(self):
data = {"a": ...} # single literal-dict init
data["b"] = ... # subscript assignment, string-constant key
data.update({"c": ..., "d": ...}) # .update() with a literal dict argument
if cond:
data["e"] = ... # conditional branches are unioned
return "<classpath>", data

Returns the union of all keys that *could* appear in ``var_name``. Returns ``None`` if any
statement involving ``var_name`` is too dynamic to resolve (reassignment to a non-dict
expression, multiple literal-dict inits, dynamic subscript keys, ``.update()`` with a
non-literal argument, etc.).
"""
keys: set[str] = set()
init_seen = False
for node in _walk_skip_nested(serialize):
# Init: ``var = {...}`` (Assign with Dict RHS).
if isinstance(node, ast.Assign) and isinstance(node.value, ast.Dict):
for tgt in node.targets:
if isinstance(tgt, ast.Name) and tgt.id == var_name:
if init_seen:
return None # multiple literal-dict inits -- unresolvable
sub = self._extract_keys_from_dict_literal(node.value, defining_cls, _seen)
if sub is None:
return None
keys |= sub
init_seen = True
# Init: ``var: dict[...] = {...}`` (AnnAssign with Dict RHS).
elif isinstance(node, ast.AnnAssign) and isinstance(node.value, ast.Dict):
if isinstance(node.target, ast.Name) and node.target.id == var_name:
if init_seen:
return None
sub = self._extract_keys_from_dict_literal(node.value, defining_cls, _seen)
if sub is None:
return None
keys |= sub
init_seen = True
# Reassignment to a non-dict expression: ``var = something_else``.
elif isinstance(node, ast.Assign):
for tgt in node.targets:
if isinstance(tgt, ast.Name) and tgt.id == var_name:
return None
elif isinstance(node, ast.AnnAssign):
if (
isinstance(node.target, ast.Name)
and node.target.id == var_name
and node.value is not None
):
return None
# Now collect mutations: ``var["k"] = ...`` and ``var.update({...})``.
for node in _walk_skip_nested(serialize):
if isinstance(node, ast.Assign):
for tgt in node.targets:
if (
isinstance(tgt, ast.Subscript)
and isinstance(tgt.value, ast.Name)
and tgt.value.id == var_name
):
if not (isinstance(tgt.slice, ast.Constant) and isinstance(tgt.slice.value, str)):
return None # dynamic / non-string subscript key
keys.add(tgt.slice.value)
if (
isinstance(node, ast.Call)
and isinstance(node.func, ast.Attribute)
and node.func.attr == "update"
and isinstance(node.func.value, ast.Name)
and node.func.value.id == var_name
):
if node.keywords or len(node.args) != 1 or not isinstance(node.args[0], ast.Dict):
return None # ``.update()`` with kwargs or non-literal-dict argument
sub = self._extract_keys_from_dict_literal(node.args[0], defining_cls, _seen)
if sub is None:
return None
keys |= sub
if not init_seen:
return None # variable was never initialized to a literal dict in this function
return keys

def _get_super_serialize_keys(
self, value: ast.expr, defining_cls: ast.ClassDef, _seen: set[str]
) -> set[str] | None:
Expand Down
Loading
Loading