NOTE(review): re-flowed from a whitespace-collapsed copy of the patch. Context-line wrapping and
indentation are reconstructed from the hunk counts -- re-check them against the target tree. The
blob ids on the ``index`` lines predate the resolver hardening in this revision (unmodelled
rebinding/mutation of the returned variable now yields "unresolvable" instead of a guessed key
set), so apply with plain ``git apply`` rather than ``--3way``.

diff --git a/scripts/ci/prek/check_trigger_serialize_init.py b/scripts/ci/prek/check_trigger_serialize_init.py
index f1050a9340ea5..1a026dcffd2e6 100755
--- a/scripts/ci/prek/check_trigger_serialize_init.py
+++ b/scripts/ci/prek/check_trigger_serialize_init.py
@@ -36,16 +36,26 @@
 through in-file base classes), flags any ``__init__`` parameter missing from the ``serialize()``
 return dict.
 
-Classes whose ``serialize()`` is built dynamically (``**spread`` of a non-``super()`` value,
-``.update()``, returning a variable, ...) or that inherit ``__init__``/``serialize()`` from a base
-class defined in another file cannot be resolved statically and are skipped -- the check never
-guesses, so it produces no false positives.
+Two ``serialize()`` shapes are resolved statically:
+
+1. **Direct dict literal**: ``return "", {"key": self.key, ...}``.
+2. **Dict via local variable**: a single literal-dict initialization (``data = {...}`` or
+   ``data: dict = {...}``), optionally followed by ``data["k"] = ...`` subscript assignments and
+   ``data.update({...})`` calls with literal dict arguments, then ``return "", data``.
+   All key-adding paths (including conditional branches) are unioned -- a key that *could* appear
+   in the output is treated as preserved on the round-trip.
+
+Classes whose ``serialize()`` is built in any other dynamic way (``**spread`` of a non-``super()``
+value, reassignment of the return variable, ``.update()`` with a non-literal argument, ...) or
+that inherit ``__init__``/``serialize()`` from a base class defined in another file cannot be
+resolved statically and are skipped -- the check never guesses, so it produces no false positives.
 """
 
 from __future__ import annotations
 
 import ast
 import sys
+from collections.abc import Iterator
 from pathlib import Path
 
 from common_prek_utils import AIRFLOW_PROVIDERS_ROOT_PATH, console
@@ -97,6 +107,21 @@ def _get_method(cls: ast.ClassDef, name: str) -> ast.FunctionDef | None:
     return None
 
 
+def _walk_skip_nested(node: ast.AST) -> Iterator[ast.AST]:
+    """
+    Yield descendants of *node* without descending into nested function/class/lambda scopes.
+
+    Yields in depth-first pre-order (a node before its children -- callers rely on this), visiting
+    what ``ast.walk`` would but stopping at boundaries that introduce a new variable scope, so
+    ``data["x"] = ...`` inside a nested helper is not attributed to the outer function's variable.
+    """
+    for child in ast.iter_child_nodes(node):
+        if isinstance(child, (ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef, ast.Lambda)):
+            continue
+        yield child
+        yield from _walk_skip_nested(child)
+
+
 class ModuleAnalyzer:
     """Resolves trigger __init__/serialize() pairs within a single module, following in-file bases."""
 
@@ -157,16 +182,25 @@ def _get_serialize_keys(self, cls: ast.ClassDef, _seen: set[str] | None = None)
             return None
         serialize, defining_cls = resolved
 
-        returns = [n for n in ast.walk(serialize) if isinstance(n, ast.Return)]
+        # Only count return statements in ``serialize()``'s own scope -- a return inside a
+        # nested helper function belongs to that helper, not to serialize.
+        returns = [n for n in _walk_skip_nested(serialize) if isinstance(n, ast.Return)]
         if len(returns) != 1 or returns[0].value is None:
             return None
         ret = returns[0].value
         if not isinstance(ret, (ast.Tuple, ast.List)) or len(ret.elts) != 2:
             return None
         payload = ret.elts[1]
-        if not isinstance(payload, ast.Dict):
-            return None
+        if isinstance(payload, ast.Dict):
+            return self._extract_keys_from_dict_literal(payload, defining_cls, _seen)
+        if isinstance(payload, ast.Name):
+            return self._extract_keys_from_local_var(payload.id, serialize, defining_cls, _seen)
+        return None
 
+    def _extract_keys_from_dict_literal(
+        self, payload: ast.Dict, defining_cls: ast.ClassDef, _seen: set[str]
+    ) -> set[str] | None:
+        """Resolve keys from a ``{...}`` literal, accepting ``**super().serialize()[1]`` spreads."""
         keys: set[str] = set()
         for key, value in zip(payload.keys, payload.values):
             if key is None:
@@ -181,6 +215,108 @@ def _get_serialize_keys(self, cls: ast.ClassDef, _seen: set[str] | None = None)
             keys.add(key.value)
         return keys
 
+    def _extract_keys_from_local_var(
+        self,
+        var_name: str,
+        serialize: ast.FunctionDef,
+        defining_cls: ast.ClassDef,
+        _seen: set[str],
+    ) -> set[str] | None:
+        """
+        Resolve keys for the dict-via-variable pattern:
+
+        .. code-block:: python
+
+            def serialize(self):
+                data = {"a": ...}  # single literal-dict init
+                data["b"] = ...  # subscript assignment, string-constant key
+                data.update({"c": ..., "d": ...})  # .update() with a literal dict argument
+                if cond:
+                    data["e"] = ...  # conditional branches are unioned
+                return "", data
+
+        Returns the union of all keys that *could* appear in ``var_name``. Returns ``None`` if any
+        statement involving ``var_name`` is too dynamic to resolve (reassignment to a non-dict
+        expression, multiple literal-dict inits, dynamic subscript keys, ``.update()`` with a
+        non-literal argument, augmented assignment, tuple-unpacking/loop/``with`` rebinding,
+        ``del var[...]``, or a dict method that adds/removes keys other than ``.update()``).
+        """
+        keys: set[str] = set()
+        init_seen = False
+        # ``id()`` of the ``Name`` nodes accepted as bindings of ``var_name``. Any other Store/Del
+        # occurrence (``var |= ...``, ``var, x = ...``, ``for var in ...``) is an unmodelled rebind.
+        accepted: set[int] = set()
+        # Dict methods that add or remove keys in ways this resolver does not model.
+        opaque_mutators = {"setdefault", "pop", "popitem", "clear", "__setitem__", "__delitem__", "__ior__"}
+        for node in _walk_skip_nested(serialize):
+            if isinstance(node, ast.Assign):
+                for tgt in node.targets:
+                    if isinstance(tgt, ast.Name) and tgt.id == var_name:
+                        # Init: ``var = {...}``; a second init or a non-dict value is unresolvable.
+                        if init_seen or not isinstance(node.value, ast.Dict):
+                            return None
+                        sub = self._extract_keys_from_dict_literal(node.value, defining_cls, _seen)
+                        if sub is None:
+                            return None
+                        keys |= sub
+                        init_seen = True
+                        accepted.add(id(tgt))
+                    elif (
+                        isinstance(tgt, ast.Subscript)
+                        and isinstance(tgt.value, ast.Name)
+                        and tgt.value.id == var_name
+                    ):
+                        # Mutation: ``var["k"] = ...`` with a string-constant key.
+                        if not (isinstance(tgt.slice, ast.Constant) and isinstance(tgt.slice.value, str)):
+                            return None  # dynamic / non-string subscript key
+                        keys.add(tgt.slice.value)
+            elif isinstance(node, ast.AnnAssign):
+                tgt = node.target
+                if isinstance(tgt, ast.Name) and tgt.id == var_name:
+                    accepted.add(id(tgt))
+                    if node.value is None:
+                        continue  # bare annotation (``var: dict``) binds nothing
+                    # Init: ``var: dict[...] = {...}``.
+                    if init_seen or not isinstance(node.value, ast.Dict):
+                        return None
+                    sub = self._extract_keys_from_dict_literal(node.value, defining_cls, _seen)
+                    if sub is None:
+                        return None
+                    keys |= sub
+                    init_seen = True
+            elif isinstance(node, ast.Name):
+                if (
+                    node.id == var_name
+                    and isinstance(node.ctx, (ast.Store, ast.Del))
+                    and id(node) not in accepted
+                ):
+                    return None  # rebound or deleted by a statement we do not model
+            elif isinstance(node, ast.Subscript):
+                if (
+                    isinstance(node.ctx, ast.Del)
+                    and isinstance(node.value, ast.Name)
+                    and node.value.id == var_name
+                ):
+                    return None  # ``del var["k"]`` removes a key
+            elif (
+                isinstance(node, ast.Call)
+                and isinstance(node.func, ast.Attribute)
+                and isinstance(node.func.value, ast.Name)
+                and node.func.value.id == var_name
+            ):
+                if node.func.attr == "update":
+                    if node.keywords or len(node.args) != 1 or not isinstance(node.args[0], ast.Dict):
+                        return None  # ``.update()`` with kwargs or non-literal-dict argument
+                    sub = self._extract_keys_from_dict_literal(node.args[0], defining_cls, _seen)
+                    if sub is None:
+                        return None
+                    keys |= sub
+                elif node.func.attr in opaque_mutators:
+                    return None
+        if not init_seen:
+            return None  # variable was never initialized to a literal dict in this function
+        return keys
+
     def _get_super_serialize_keys(
         self, value: ast.expr, defining_cls: ast.ClassDef, _seen: set[str]
     ) -> set[str] | None:
diff --git a/scripts/tests/ci/prek/test_check_trigger_serialize_init.py b/scripts/tests/ci/prek/test_check_trigger_serialize_init.py
new file mode 100644
index 0000000000000..c1d1e2b235095
--- /dev/null
+++ b/scripts/tests/ci/prek/test_check_trigger_serialize_init.py
@@ -0,0 +1,342 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import textwrap
+from pathlib import Path
+
+import check_trigger_serialize_init as check_module
+import pytest
+from check_trigger_serialize_init import ModuleAnalyzer, _get_init_param_names
+
+
+@pytest.fixture
+def analyzer_factory(tmp_path: Path, monkeypatch: pytest.MonkeyPatch):
+    """
+    Factory: write Python source to a temp file rooted at *tmp_path* and return a
+    ModuleAnalyzer for it.
+
+    ``ModuleAnalyzer.get_violations()`` computes paths relative to
+    ``AIRFLOW_PROVIDERS_ROOT_PATH`` for its violation report, so we redirect that
+    constant to *tmp_path* for the lifetime of each test. We resolve both ends of
+    the comparison so symlinked tmp dirs (``/var`` vs ``/private/var`` on macOS)
+    don't cause spurious ``relative_to`` failures.
+    """
+    resolved_root = tmp_path.resolve()
+    monkeypatch.setattr(check_module, "AIRFLOW_PROVIDERS_ROOT_PATH", resolved_root)
+
+    def _make(source: str, *, name: str = "trigger.py") -> ModuleAnalyzer:
+        path = resolved_root / name
+        path.write_text(textwrap.dedent(source))
+        return ModuleAnalyzer(path)
+
+    return _make
+
+
+def _missing_init_params(analyzer: ModuleAnalyzer, class_name: str) -> set[str] | None:
+    """Compute ``__init__`` params not preserved by ``serialize()``.
+
+    Returns ``None`` when ``serialize()`` is unresolvable. Mirrors the core check that
+    ``ModuleAnalyzer.get_violations()`` performs, without the report-path machinery.
+    """
+    cls = analyzer.classes[class_name]
+    init_resolved = analyzer._resolve_method(cls, "__init__")
+    assert init_resolved is not None, f"{class_name}.__init__ should resolve in this test"
+    params = _get_init_param_names(init_resolved[0])
+    serialize_keys = analyzer._get_serialize_keys(cls)
+    if serialize_keys is None:
+        return None
+    return params - serialize_keys
+
+
+class TestDictLiteralResolver:
+    """The original resolver path: ``return "", {"key": ..., ...}``."""
+
+    def test_literal_dict_keys_resolved(self, analyzer_factory) -> None:
+        analyzer = analyzer_factory(
+            """
+            class FooTrigger:
+                def __init__(self, a, b, c):
+                    self.a, self.b, self.c = a, b, c
+                def serialize(self):
+                    return "x.FooTrigger", {"a": self.a, "b": self.b, "c": self.c}
+            """,
+        )
+        keys = analyzer._get_serialize_keys(analyzer.classes["FooTrigger"])
+        assert keys == {"a", "b", "c"}
+
+    def test_missing_param_flagged(self, analyzer_factory) -> None:
+        analyzer = analyzer_factory(
+            """
+            class FooTrigger:
+                def __init__(self, a, b, c):
+                    self.a, self.b, self.c = a, b, c
+                def serialize(self):
+                    return "x.FooTrigger", {"a": self.a, "b": self.b}
+            """,
+        )
+        assert _missing_init_params(analyzer, "FooTrigger") == {"c"}
+        assert analyzer.get_violations() == [("FooTrigger", ["c"])]
+
+    def test_super_spread_resolved_via_in_file_base(self, analyzer_factory) -> None:
+        analyzer = analyzer_factory(
+            """
+            class BaseTrigger:
+                def __init__(self, a):
+                    self.a = a
+                def serialize(self):
+                    return "x.BaseTrigger", {"a": self.a}
+
+            class ChildTrigger(BaseTrigger):
+                def __init__(self, a, b):
+                    super().__init__(a)
+                    self.b = b
+                def serialize(self):
+                    return "x.ChildTrigger", {**super().serialize()[1], "b": self.b}
+            """,
+        )
+        keys = analyzer._get_serialize_keys(analyzer.classes["ChildTrigger"])
+        assert keys == {"a", "b"}
+
+    def test_non_constant_key_unresolvable(self, analyzer_factory) -> None:
+        analyzer = analyzer_factory(
+            """
+            KEY = "a"
+
+            class FooTrigger:
+                def __init__(self, a):
+                    self.a = a
+                def serialize(self):
+                    return "x.FooTrigger", {KEY: self.a}
+            """,
+        )
+        assert analyzer._get_serialize_keys(analyzer.classes["FooTrigger"]) is None
+
+
+class TestDictViaVariableResolver:
+    """New resolver: ``data = {...}; data['x'] = ...; data.update({...}); return ..., data``."""
+
+    def test_plain_assign_init(self, analyzer_factory) -> None:
+        analyzer = analyzer_factory(
+            """
+            class FooTrigger:
+                def __init__(self, a, b):
+                    self.a, self.b = a, b
+                def serialize(self):
+                    data = {"a": self.a, "b": self.b}
+                    return "x.FooTrigger", data
+            """,
+        )
+        keys = analyzer._get_serialize_keys(analyzer.classes["FooTrigger"])
+        assert keys == {"a", "b"}
+
+    def test_annotated_assign_init(self, analyzer_factory) -> None:
+        analyzer = analyzer_factory(
+            """
+            from typing import Any
+
+            class FooTrigger:
+                def __init__(self, a, b):
+                    self.a, self.b = a, b
+                def serialize(self):
+                    data: dict[str, Any] = {"a": self.a, "b": self.b}
+                    return "x.FooTrigger", data
+            """,
+        )
+        keys = analyzer._get_serialize_keys(analyzer.classes["FooTrigger"])
+        assert keys == {"a", "b"}
+
+    def test_subscript_assignment_adds_keys(self, analyzer_factory) -> None:
+        analyzer = analyzer_factory(
+            """
+            class FooTrigger:
+                def __init__(self, a, b, c):
+                    self.a, self.b, self.c = a, b, c
+                def serialize(self):
+                    data = {"a": self.a}
+                    data["b"] = self.b
+                    data["c"] = self.c
+                    return "x.FooTrigger", data
+            """,
+        )
+        keys = analyzer._get_serialize_keys(analyzer.classes["FooTrigger"])
+        assert keys == {"a", "b", "c"}
+
+    def test_update_with_literal_dict_adds_keys(self, analyzer_factory) -> None:
+        analyzer = analyzer_factory(
+            """
+            class FooTrigger:
+                def __init__(self, a, b, c):
+                    self.a, self.b, self.c = a, b, c
+                def serialize(self):
+                    data = {"a": self.a}
+                    data.update({"b": self.b, "c": self.c})
+                    return "x.FooTrigger", data
+            """,
+        )
+        keys = analyzer._get_serialize_keys(analyzer.classes["FooTrigger"])
+        assert keys == {"a", "b", "c"}
+
+    def test_conditional_branches_unioned(self, analyzer_factory) -> None:
+        """Mirrors the WorkflowTrigger pattern: keys conditionally added are still preserved."""
+        analyzer = analyzer_factory(
+            """
+            FLAG = True
+
+            class FooTrigger:
+                def __init__(self, a, b, c):
+                    self.a, self.b, self.c = a, b, c
+                def serialize(self):
+                    data = {"a": self.a}
+                    if FLAG:
+                        data["b"] = self.b
+                    else:
+                        data["c"] = self.c
+                    return "x.FooTrigger", data
+            """,
+        )
+        keys = analyzer._get_serialize_keys(analyzer.classes["FooTrigger"])
+        assert keys == {"a", "b", "c"}
+
+    def test_missing_param_flagged_with_var_pattern(self, analyzer_factory) -> None:
+        analyzer = analyzer_factory(
+            """
+            class FooTrigger:
+                def __init__(self, a, b, c):
+                    self.a, self.b, self.c = a, b, c
+                def serialize(self):
+                    data = {"a": self.a, "b": self.b}
+                    return "x.FooTrigger", data
+            """,
+        )
+        assert _missing_init_params(analyzer, "FooTrigger") == {"c"}
+        assert analyzer.get_violations() == [("FooTrigger", ["c"])]
+
+    @pytest.mark.parametrize(
+        "serialize_body",
+        [
+            pytest.param(
+                ["data = {}", "data = some_helper()", "return 'x.T', data"],
+                id="reassign-non-dict",
+            ),
+            pytest.param(
+                ["data = {'a': 1}", "data = {'b': 2}", "return 'x.T', data"],
+                id="multiple-literal-inits",
+            ),
+            pytest.param(
+                ["k = 'a'", "data = {}", "data[k] = self.a", "return 'x.T', data"],
+                id="dynamic-subscript-key",
+            ),
+            pytest.param(
+                ["data = {}", "data.update(other)", "return 'x.T', data"],
+                id="update-with-non-literal",
+            ),
+            pytest.param(
+                ["data = {}", "data.update(a=1)", "return 'x.T', data"],
+                id="update-with-kwargs",
+            ),
+            pytest.param(
+                ["data = {'a': self.a}", "data |= {'b': 2}", "return 'x.T', data"],
+                id="augmented-merge",
+            ),
+            pytest.param(
+                ["data = {}", "data, extra = build(), 1", "return 'x.T', data"],
+                id="tuple-unpack-rebind",
+            ),
+            pytest.param(
+                ["data = {}", "data.setdefault('a', self.a)", "return 'x.T', data"],
+                id="setdefault-mutation",
+            ),
+            pytest.param(
+                ["data = {'a': self.a}", "del data['a']", "return 'x.T', data"],
+                id="del-subscript",
+            ),
+            pytest.param(
+                ["return 'x.T', data"],
+                id="never-initialized",
+            ),
+        ],
+    )
+    def test_unresolvable_returns_none(self, analyzer_factory, serialize_body: list[str]) -> None:
+        """The resolver returns ``None`` rather than guess on any of these dynamic shapes."""
+        # ``serialize_body`` is a list of statement lines; join with leading whitespace
+        # matching the column of ``{body}`` in the f-string template (20 spaces) so all
+        # body lines land at the same column after textwrap.dedent.
+        body = "\n                    ".join(serialize_body)
+        analyzer = analyzer_factory(
+            f"""
+            class FooTrigger:
+                def __init__(self, a):
+                    self.a = a
+                def serialize(self):
+                    {body}
+            """,
+        )
+        assert analyzer._get_serialize_keys(analyzer.classes["FooTrigger"]) is None
+
+    def test_nested_helper_does_not_pollute_outer_var(self, analyzer_factory) -> None:
+        """A same-named ``data`` inside a nested function must not contribute keys to the outer var."""
+        analyzer = analyzer_factory(
+            """
+            class FooTrigger:
+                def __init__(self, a):
+                    self.a = a
+                def serialize(self):
+                    data = {"a": self.a}
+
+                    def _bogus():
+                        data = {"b": 1}  # different scope -- must be ignored
+                        data["c"] = 2
+                        return data
+
+                    return "x.FooTrigger", data
+            """,
+        )
+        keys = analyzer._get_serialize_keys(analyzer.classes["FooTrigger"])
+        # Outer ``data`` only has ``a``; the nested helper's ``b`` and ``c`` must not leak in.
+        assert keys == {"a"}
+
+
+class TestUnresolvedSerializeShapes:
+    """Verify shapes that should still be skipped post-extension."""
+
+    def test_multiple_returns_unresolvable(self, analyzer_factory) -> None:
+        analyzer = analyzer_factory(
+            """
+            class FooTrigger:
+                def __init__(self, a):
+                    self.a = a
+                def serialize(self):
+                    if self.a:
+                        return "x.FooTrigger", {"a": self.a}
+                    return "x.FooTrigger", {"a": None}
+            """,
+        )
+        # Original constraint: exactly one return statement.
+        assert analyzer._get_serialize_keys(analyzer.classes["FooTrigger"]) is None
+
+    def test_return_value_unrecognized_shape(self, analyzer_factory) -> None:
+        analyzer = analyzer_factory(
+            """
+            class FooTrigger:
+                def __init__(self, a):
+                    self.a = a
+                def serialize(self):
+                    return self._build_payload()
+            """,
+        )
+        assert analyzer._get_serialize_keys(analyzer.classes["FooTrigger"]) is None