diff --git a/airflow-core/newsfragments/67288.improvement.rst b/airflow-core/newsfragments/67288.improvement.rst new file mode 100644 index 0000000000000..03293e4ffa240 --- /dev/null +++ b/airflow-core/newsfragments/67288.improvement.rst @@ -0,0 +1 @@ +Speed up ``TaskGroup.topological_sort`` across Dag shapes (chain, diamond, layered, reverse-chain); benchmarks show roughly 2-8x faster on large groups. diff --git a/airflow-core/src/airflow/serialization/definitions/taskgroup.py b/airflow-core/src/airflow/serialization/definitions/taskgroup.py index d971c303c7c53..5db656019f1cb 100644 --- a/airflow-core/src/airflow/serialization/definitions/taskgroup.py +++ b/airflow-core/src/airflow/serialization/definitions/taskgroup.py @@ -18,7 +18,6 @@ from __future__ import annotations -import copy import functools import operator import weakref @@ -217,35 +216,79 @@ def iter_mapped_task_groups(self) -> Iterator[SerializedMappedTaskGroup]: def topological_sort(self) -> list[DAGNode]: """ - Sorts children in topographical order. + Sort children topologically — a task always comes after its upstream dependencies. - A task in the result would come after any of its upstream dependencies. + See ``TaskGroup.topological_sort`` in task-sdk for the algorithm. Cycles are + treated as corrupt input: ``DAG.check_cycle`` rejects cyclic Dags before + serialization, so a cycle reaching this code indicates malformed serialized data, + and we raise ``ValueError`` rather than silently looping forever. """ - # This uses a modified version of Kahn's Topological Sort algorithm to - # not have to pre-compute the "in-degree" of the nodes. - graph_unsorted = copy.copy(self.children) - graph_sorted: list[DAGNode] = [] - if not self.children: - return graph_sorted - while graph_unsorted: - for node in list(graph_unsorted.values()): - for edge in node.upstream_list: - if edge.node_id in graph_unsorted: + children = self.children + if not children: + return [] + nodes = list(children.values()) + id_to_idx = {nid: i for i, nid in enumerate(children)} + projected = [self._project_child_deps(i, c, id_to_idx) for i, c in enumerate(nodes)] + return self._sweep_projection(nodes, projected) + + def _project_child_deps( + self, child_idx: int, child: DAGNode, id_to_idx: dict[str, int] + ) -> tuple[int, ...]: + upstream_ids = child.upstream_task_ids + if not upstream_ids: + return () + sib_deps: set[int] = set() + for edge_id in upstream_ids: + j = id_to_idx.get(edge_id) + if j is not None: + sib_deps.add(j) + continue + tg = self.dag.get_task(edge_id).task_group + while tg is not None: + j = id_to_idx.get(tg.node_id) + if j is not None: + sib_deps.add(j) + break + tg = tg.parent_group + sib_deps.discard(child_idx) + return tuple(sib_deps) + + def _sweep_projection(self, nodes: list[DAGNode], projected: list[tuple[int, ...]]) -> list[DAGNode]: + n = len(nodes) + emitted = bytearray(n) + order: list[DAGNode] = [] + order_append = order.append + pending: list[int] = [] + pending_append = pending.append + for i in range(n): + blocked = False + for d in projected[i]: + if not emitted[d]: + blocked = True + break + if blocked: + pending_append(i) + continue + emitted[i] = 1 + order_append(nodes[i]) + while pending: + next_pending: list[int] = [] + next_pending_append = next_pending.append + for i in pending: + blocked = False + for d in projected[i]: + if not emitted[d]: + blocked = True break - # Check for task's group is a child (or grand child) of this TG, - tg = edge.task_group - while tg: - if tg.node_id in graph_unsorted: - break - tg = tg.parent_group - - if tg: - # We are already going to visit that TG - break - else: - del graph_unsorted[node.node_id] - graph_sorted.append(node) - return graph_sorted + if blocked: + next_pending_append(i) + continue + emitted[i] = 1 + order_append(nodes[i]) + if len(next_pending) == len(pending): + raise ValueError(f"A cyclic dependency occurred in dag: {self.dag_id}") + pending = next_pending + return order def add(self, node: DAGNode) -> DAGNode: # Set the TG first, as setting it might change the return value of node_id! diff --git a/airflow-core/tests/unit/utils/test_task_group.py b/airflow-core/tests/unit/utils/test_task_group.py index 3b62ad75a7290..ffc217fc0789d 100644 --- a/airflow-core/tests/unit/utils/test_task_group.py +++ b/airflow-core/tests/unit/utils/test_task_group.py @@ -1117,6 +1117,34 @@ def nested_topo(group): ] +def test_topological_sort_serialized_layered(): + """SerializedTaskGroup.topological_sort emits a valid order after DAG round-trip. + + Exercises the projected-sweep path on the serialization variant (which is otherwise + untested), using a layered shape that forces multi-pass behavior. + """ + with DAG("test_topo_sort_serialized", schedule=None, start_date=DEFAULT_DATE) as dag: + layers: list[list[BaseOperator]] = [] + for layer_idx in range(4): + cur = [EmptyOperator(task_id=f"L{layer_idx}_t{i}") for i in range(3)] + if layers: + for upstream in layers[-1]: + upstream >> cur + layers.append(cur) + + serialized = create_scheduler_dag(dag) + order = [node.node_id for node in serialized.task_group.topological_sort()] + position = {nid: i for i, nid in enumerate(order)} + + assert set(position) == {t.task_id for layer in layers for t in layer} + for layer_idx in range(len(layers) - 1): + for upstream in layers[layer_idx]: + for downstream in layers[layer_idx + 1]: + assert position[upstream.task_id] < position[downstream.task_id], ( + f"{upstream.task_id!r} must precede {downstream.task_id!r}, got {order!r}" + ) + + def test_task_group_arrow_with_setup_group(): with DAG(dag_id="setup_group_teardown_group") as dag: with TaskGroup("group_1") as g1: diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py b/task-sdk/src/airflow/sdk/definitions/taskgroup.py index 50527f6b43bcd..67376cb817ae4 100644 --- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py +++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py @@ -523,57 +523,90 @@ def hierarchical_alphabetical_sort(self): key=lambda node: (not isinstance(node, TaskGroup), node.node_id), ) - def topological_sort(self): + def topological_sort(self) -> list[DAGNode]: """ - Sorts children in topographical order, such that a task comes after any of its upstream dependencies. + Sort children topologically — a task always comes after its upstream dependencies. - :return: list of tasks in topological order + Projects each child's per-task upstream IDs onto sibling-level integer indices once, + then runs a greedy multi-pass sweep using a bytearray-backed emission flag. Equivalent + in emission order to the previous modified-Kahn implementation, but moves the per-edge + ``upstream_list`` materialization and ``parent_group`` walks out of the sweep's inner + loop so they happen once per call instead of once per outer-loop pass. """ - # This uses a modified version of Kahn's Topological Sort algorithm to - # not have to pre-compute the "in-degree" of the nodes. - graph_unsorted = copy.copy(self.children) - - graph_sorted: list[DAGNode] = [] - - # special case - if not self.children: - return graph_sorted - - # Run until the unsorted graph is empty. - while graph_unsorted: - # Go through each of the node/edges pairs in the unsorted graph. If a set of edges doesn't contain - # any nodes that haven't been resolved, that is, that are still in the unsorted graph, remove the - # pair from the unsorted graph, and append it to the sorted graph. Note here that by using - # the values() method for iterating, a copy of the unsorted graph is used, allowing us to modify - # the unsorted graph as we move through it. - # - # We also keep a flag for checking that graph is acyclic, which is true if any nodes are resolved - # during each pass through the graph. If not, we need to exit as the graph therefore can't be - # sorted. - acyclic = False - for node in list(graph_unsorted.values()): - for edge in node.upstream_list: - if edge.node_id in graph_unsorted: - break - # Check for task's group is a child (or grand child) of this TG, - tg = edge.task_group - while tg: - if tg.node_id in graph_unsorted: - break - tg = tg.parent_group - - if tg: - # We are already going to visit that TG + children = self.children + if not children: + return [] + nodes = list(children.values()) + id_to_idx = {nid: i for i, nid in enumerate(children)} + projected = [self._project_child_deps(i, c, id_to_idx) for i, c in enumerate(nodes)] + return self._sweep_projection(nodes, projected) + + def _project_child_deps( + self, child_idx: int, child: DAGNode, id_to_idx: dict[str, int] + ) -> tuple[int, ...]: + # Project one child's per-task upstream IDs onto sibling-level integer indices. + # Self-deps are filtered once at the end via ``discard`` so the inner loop stays tight. + upstream_ids = child.upstream_task_ids + if not upstream_ids: + return () + sib_deps: set[int] = set() + for edge_id in upstream_ids: + j = id_to_idx.get(edge_id) + if j is not None: + sib_deps.add(j) + continue + tg = self.dag.get_task(edge_id).task_group + while tg is not None: + j = id_to_idx.get(tg.node_id) + if j is not None: + sib_deps.add(j) + break + tg = tg.parent_group + sib_deps.discard(child_idx) + return tuple(sib_deps) + + def _sweep_projection(self, nodes: list[DAGNode], projected: list[tuple[int, ...]]) -> list[DAGNode]: + # Greedy multi-pass sweep. emitted[i] == 1 iff nodes[i] has been emitted. + # Pass 1 iterates range(n) directly; only blocked nodes are recorded into + # ``pending`` and re-checked in subsequent passes. Avoids paying for a + # ``list(range(n))`` allocation on single-pass shapes (the common case) while + # still skipping already-emitted nodes on multi-pass shapes (e.g. a diamond's + # single trailing sink). + n = len(nodes) + emitted = bytearray(n) + order: list[DAGNode] = [] + order_append = order.append + pending: list[int] = [] + pending_append = pending.append + for i in range(n): + blocked = False + for d in projected[i]: + if not emitted[d]: + blocked = True + break + if blocked: + pending_append(i) + continue + emitted[i] = 1 + order_append(nodes[i]) + while pending: + next_pending: list[int] = [] + next_pending_append = next_pending.append + for i in pending: + blocked = False + for d in projected[i]: + if not emitted[d]: + blocked = True break - else: - acyclic = True - del graph_unsorted[node.node_id] - graph_sorted.append(node) - - if not acyclic: + if blocked: + next_pending_append(i) + continue + emitted[i] = 1 + order_append(nodes[i]) + if len(next_pending) == len(pending): raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}") - - return graph_sorted + pending = next_pending + return order def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]: """ diff --git a/task-sdk/tests/task_sdk/definitions/test_taskgroup.py b/task-sdk/tests/task_sdk/definitions/test_taskgroup.py index 18c7f65faf2e4..d1ba11e3056c9 100644 --- a/task-sdk/tests/task_sdk/definitions/test_taskgroup.py +++ b/task-sdk/tests/task_sdk/definitions/test_taskgroup.py @@ -957,3 +957,147 @@ def test_getitem_missing_is_key_error(self): with pytest.raises(KeyError): tg["nonexistent"] + + +# --- topological_sort: cross-shape correctness --- +# +# Mirrors the shapes covered by the benchmark gist referenced from PR #67288 +# (https://gist.github.com/shahar1/9c61dc9f34f7e77cd29cfb9d67af7ceb). +# Wall-clock timing is intentionally not asserted here — CI runners are too +# variable for ms thresholds to be meaningful. The gist above can be run +# manually to gauge performance. + + +def _make_chain(n: int) -> DAG: + with DAG(f"chain_{n}", schedule=None, start_date=DEFAULT_DATE) as dag: + prev = None + for i in range(n): + t = EmptyOperator(task_id=f"t{i}") + if prev is not None: + prev >> t + prev = t + return dag + + +def _make_reverse_chain(n: int) -> DAG: + with DAG(f"reverse_chain_{n}", schedule=None, start_date=DEFAULT_DATE) as dag: + tasks = [EmptyOperator(task_id=f"t{n - 1 - i}") for i in range(n)] + by_id = {t.task_id: t for t in tasks} + for i in range(n - 1): + by_id[f"t{i}"] >> by_id[f"t{i + 1}"] + return dag + + +def _make_diamond(n: int) -> DAG: + with DAG(f"diamond_{n}", schedule=None, start_date=DEFAULT_DATE) as dag: + root = EmptyOperator(task_id="root") + sink = EmptyOperator(task_id="sink") + middles = [EmptyOperator(task_id=f"m{i}") for i in range(max(n - 2, 1))] + root >> middles >> sink + return dag + + +def _make_independent(n: int) -> DAG: + with DAG(f"independent_{n}", schedule=None, start_date=DEFAULT_DATE) as dag: + for i in range(n): + EmptyOperator(task_id=f"t{i}") + return dag + + +def _make_layered(n: int, layers: int = 4) -> DAG: + per_layer = max(n // layers, 1) + with DAG(f"layered_{n}", schedule=None, start_date=DEFAULT_DATE) as dag: + prev_layer: list[EmptyOperator] = [] + for layer in range(layers): + cur = [EmptyOperator(task_id=f"L{layer}_t{i}") for i in range(per_layer)] + if prev_layer: + for upstream in prev_layer: + upstream >> cur + prev_layer = cur + return dag + + +def _make_nested_groups(n: int, depth: int = 3) -> DAG: + per_group = max(n // (depth * depth), 1) + with DAG(f"nested_{n}", schedule=None, start_date=DEFAULT_DATE) as dag: + + def build_group(level: int, idx: int) -> TaskGroup: + with TaskGroup(group_id=f"g{level}_{idx}") as tg: + prev = None + for i in range(per_group): + t = EmptyOperator(task_id=f"l{level}_g{idx}_t{i}") + if prev is not None: + prev >> t + prev = t + if level + 1 < depth: + inner_prev = None + for j in range(depth): + inner = build_group(level + 1, j) + if inner_prev is not None: + inner_prev >> inner + inner_prev = inner + return tg + + top_prev = None + for j in range(depth): + top = build_group(0, j) + if top_prev is not None: + top_prev >> top + top_prev = top + return dag + + +def _project_sibling(group: TaskGroup, upstream_task_id: str, child_id: str) -> str | None: + """Mirror of TaskGroup._project_child_deps' projection, returning a string ID.""" + children = group.children + if upstream_task_id in children: + return upstream_task_id if upstream_task_id != child_id else None + upstream = group.dag.get_task(upstream_task_id) + tg = upstream.task_group + while tg is not None: + if tg.node_id in children: + return tg.node_id if tg.node_id != child_id else None + tg = tg.parent_group + return None + + +def _walk_groups(tg: TaskGroup): + yield tg + for child in tg.children.values(): + if isinstance(child, TaskGroup): + yield from _walk_groups(child) + + +def _assert_valid_topological_order(group: TaskGroup, order: list[str]) -> None: + position = {node_id: i for i, node_id in enumerate(order)} + assert set(position) == set(group.children), ( + f"topological_sort output {order!r} does not cover children of {group.node_id!r}" + ) + for child_id, child in group.children.items(): + for upstream_id in child.upstream_task_ids: + sib = _project_sibling(group, upstream_id, child_id) + if sib is None: + continue + assert position[sib] < position[child_id], ( + f"In group {group.node_id!r}: sibling {sib!r} must precede {child_id!r}, got order {order!r}" + ) + + +@pytest.mark.parametrize( + ("shape", "builder"), + [ + ("chain", _make_chain), + ("rev-chain", _make_reverse_chain), + ("diamond", _make_diamond), + ("independent", _make_independent), + ("layered", _make_layered), + ("nested", _make_nested_groups), + ], +) +@pytest.mark.parametrize("n", [20, 100]) +def test_topological_sort_shape_correctness(shape, builder, n): + """topological_sort emits a valid order for every nested group across DAG shapes.""" + dag = builder(n) + for group in _walk_groups(dag.task_group): + order = [node.node_id for node in group.topological_sort()] + _assert_valid_topological_order(group, order)