From 0452161432243fd9905dcf27d49d65546755472c Mon Sep 17 00:00:00 2001 From: Shahar Epstein <60007259+shahar1@users.noreply.github.com> Date: Thu, 21 May 2026 18:36:01 +0300 Subject: [PATCH 1/7] Speed up TaskGroup.topological_sort with int-indexed projected sweep The previous modified-Kahn implementation re-derived each child's upstream edges (materializing ``upstream_list`` and walking ``parent_group``) on every outer-loop pass. Project per-task upstream IDs onto sibling-level integer indices once up front, then run a greedy multi-pass sweep against that projection with a ``bytearray`` emission flag. Emission order is identical to the previous implementation; existing order-sensitive tests cover the contract. Same change is mirrored in SerializedTaskGroup. --- .../serialization/definitions/taskgroup.py | 104 ++++++++++---- .../src/airflow/sdk/definitions/taskgroup.py | 132 ++++++++++++------ 2 files changed, 163 insertions(+), 73 deletions(-) diff --git a/airflow-core/src/airflow/serialization/definitions/taskgroup.py b/airflow-core/src/airflow/serialization/definitions/taskgroup.py index d971c303c7c53..8d2362c6a6a0c 100644 --- a/airflow-core/src/airflow/serialization/definitions/taskgroup.py +++ b/airflow-core/src/airflow/serialization/definitions/taskgroup.py @@ -18,7 +18,6 @@ from __future__ import annotations -import copy import functools import operator import weakref @@ -217,35 +216,86 @@ def iter_mapped_task_groups(self) -> Iterator[SerializedMappedTaskGroup]: def topological_sort(self) -> list[DAGNode]: """ - Sorts children in topographical order. + Sort children topologically — a task always comes after its upstream dependencies. - A task in the result would come after any of its upstream dependencies. + See ``TaskGroup.topological_sort`` in task-sdk for the algorithm. Cycle handling: + cycles are caught at deserialization time, so they should never reach this code; if + one does we raise ``ValueError`` rather than silently looping forever. """ - # This uses a modified version of Kahn's Topological Sort algorithm to - # not have to pre-compute the "in-degree" of the nodes. - graph_unsorted = copy.copy(self.children) - graph_sorted: list[DAGNode] = [] - if not self.children: - return graph_sorted - while graph_unsorted: - for node in list(graph_unsorted.values()): - for edge in node.upstream_list: - if edge.node_id in graph_unsorted: + children = self.children + if not children: + return [] + + nodes = list(children.values()) + n = len(nodes) + id_to_idx = {nid: i for i, nid in enumerate(children)} + + projected: list[tuple[int, ...]] = [()] * n + for i, child in enumerate(nodes): + projected[i] = self._project_child_deps(i, child, id_to_idx) + + return self._sweep_projection(nodes, projected) + + def _project_child_deps( + self, child_idx: int, child: DAGNode, id_to_idx: dict[str, int] + ) -> tuple[int, ...]: + upstream_ids = child.upstream_task_ids + if not upstream_ids: + return () + sib_deps: set[int] = set() + for edge_id in upstream_ids: + j = id_to_idx.get(edge_id) + if j is not None: + if j != child_idx: + sib_deps.add(j) + continue + edge = self.dag.get_task(edge_id) + tg = edge.task_group + while tg is not None: + anc_idx = id_to_idx.get(tg.node_id) + if anc_idx is not None: + if anc_idx != child_idx: + sib_deps.add(anc_idx) + break + tg = tg.parent_group + return tuple(sib_deps) + + def _sweep_projection(self, nodes: list[DAGNode], projected: list[tuple[int, ...]]) -> list[DAGNode]: + n = len(nodes) + emitted = bytearray(n) + order: list[DAGNode] = [] + order_append = order.append + pending: list[int] = [] + pending_append = pending.append + for i in range(n): + blocked = False + for d in projected[i]: + if not emitted[d]: + blocked = True + break + if blocked: + pending_append(i) + continue + emitted[i] = 1 + order_append(nodes[i]) + while pending: + next_pending: list[int] = [] + next_pending_append = next_pending.append + for i in pending: + blocked = False + for d in projected[i]: + if not emitted[d]: + blocked = True break - # Check for task's group is a child (or grand child) of this TG, - tg = edge.task_group - while tg: - if tg.node_id in graph_unsorted: - break - tg = tg.parent_group - - if tg: - # We are already going to visit that TG - break - else: - del graph_unsorted[node.node_id] - graph_sorted.append(node) - return graph_sorted + if blocked: + next_pending_append(i) + continue + emitted[i] = 1 + order_append(nodes[i]) + if len(next_pending) == len(pending): + raise ValueError(f"A cyclic dependency occurred in dag: {self.dag_id}") + pending = next_pending + return order def add(self, node: DAGNode) -> DAGNode: # Set the TG first, as setting it might change the return value of node_id! diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py b/task-sdk/src/airflow/sdk/definitions/taskgroup.py index 50527f6b43bcd..0a491b64c3e28 100644 --- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py +++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py @@ -523,57 +523,97 @@ def hierarchical_alphabetical_sort(self): key=lambda node: (not isinstance(node, TaskGroup), node.node_id), ) - def topological_sort(self): + def topological_sort(self) -> list[DAGNode]: """ - Sorts children in topographical order, such that a task comes after any of its upstream dependencies. + Sort children topologically — a task always comes after its upstream dependencies. - :return: list of tasks in topological order + Projects each child's per-task upstream IDs onto sibling-level integer indices once, + then runs a greedy multi-pass sweep using a bytearray-backed emission flag. Equivalent + in emission order to the previous modified-Kahn implementation, but moves the per-edge + ``upstream_list`` materialization and ``parent_group`` walks out of the sweep's inner + loop so they happen once per call instead of once per outer-loop pass. """ - # This uses a modified version of Kahn's Topological Sort algorithm to - # not have to pre-compute the "in-degree" of the nodes. - graph_unsorted = copy.copy(self.children) - - graph_sorted: list[DAGNode] = [] - - # special case - if not self.children: - return graph_sorted - - # Run until the unsorted graph is empty. - while graph_unsorted: - # Go through each of the node/edges pairs in the unsorted graph. If a set of edges doesn't contain - # any nodes that haven't been resolved, that is, that are still in the unsorted graph, remove the - # pair from the unsorted graph, and append it to the sorted graph. Note here that by using - # the values() method for iterating, a copy of the unsorted graph is used, allowing us to modify - # the unsorted graph as we move through it. - # - # We also keep a flag for checking that graph is acyclic, which is true if any nodes are resolved - # during each pass through the graph. If not, we need to exit as the graph therefore can't be - # sorted. - acyclic = False - for node in list(graph_unsorted.values()): - for edge in node.upstream_list: - if edge.node_id in graph_unsorted: - break - # Check for task's group is a child (or grand child) of this TG, - tg = edge.task_group - while tg: - if tg.node_id in graph_unsorted: - break - tg = tg.parent_group - - if tg: - # We are already going to visit that TG - break - else: - acyclic = True - del graph_unsorted[node.node_id] - graph_sorted.append(node) + children = self.children + if not children: + return [] - if not acyclic: + nodes = list(children.values()) + n = len(nodes) + id_to_idx = {nid: i for i, nid in enumerate(children)} + + projected: list[tuple[int, ...]] = [()] * n + for i, child in enumerate(nodes): + projected[i] = self._project_child_deps(i, child, id_to_idx) + + return self._sweep_projection(nodes, projected) + + def _project_child_deps( + self, child_idx: int, child: DAGNode, id_to_idx: dict[str, int] + ) -> tuple[int, ...]: + # Project one child's per-task upstream IDs onto sibling-level integer indices. + upstream_ids = child.upstream_task_ids + if not upstream_ids: + return () + sib_deps: set[int] = set() + for edge_id in upstream_ids: + j = id_to_idx.get(edge_id) + if j is not None: + if j != child_idx: + sib_deps.add(j) + continue + edge = self.dag.get_task(edge_id) + tg = edge.task_group + while tg is not None: + anc_idx = id_to_idx.get(tg.node_id) + if anc_idx is not None: + if anc_idx != child_idx: + sib_deps.add(anc_idx) + break + tg = tg.parent_group + return tuple(sib_deps) + + def _sweep_projection(self, nodes: list[DAGNode], projected: list[tuple[int, ...]]) -> list[DAGNode]: + # Greedy multi-pass sweep. emitted[i] == 1 iff nodes[i] has been emitted. + # Pass 1 iterates range(n) directly; only blocked nodes are recorded into + # ``pending`` and re-checked in subsequent passes. Avoids paying for a + # ``list(range(n))`` allocation on single-pass shapes (the common case) while + # still skipping already-emitted nodes on multi-pass shapes (e.g. a diamond's + # single trailing sink). + n = len(nodes) + emitted = bytearray(n) + order: list[DAGNode] = [] + order_append = order.append + pending: list[int] = [] + pending_append = pending.append + for i in range(n): + blocked = False + for d in projected[i]: + if not emitted[d]: + blocked = True + break + if blocked: + pending_append(i) + continue + emitted[i] = 1 + order_append(nodes[i]) + while pending: + next_pending: list[int] = [] + next_pending_append = next_pending.append + for i in pending: + blocked = False + for d in projected[i]: + if not emitted[d]: + blocked = True + break + if blocked: + next_pending_append(i) + continue + emitted[i] = 1 + order_append(nodes[i]) + if len(next_pending) == len(pending): raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}") - - return graph_sorted + pending = next_pending + return order def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]: """ From 77457a40acbe7d36bbd1b756987d29d8a5fe93d4 Mon Sep 17 00:00:00 2001 From: Shahar Epstein <60007259+shahar1@users.noreply.github.com> Date: Thu, 21 May 2026 18:39:55 +0300 Subject: [PATCH 2/7] Add newsfragment --- airflow-core/newsfragments/67288.improvement.rst | 1 + 1 file changed, 1 insertion(+) create mode 100644 airflow-core/newsfragments/67288.improvement.rst diff --git a/airflow-core/newsfragments/67288.improvement.rst b/airflow-core/newsfragments/67288.improvement.rst new file mode 100644 index 0000000000000..5ca44781caf70 --- /dev/null +++ b/airflow-core/newsfragments/67288.improvement.rst @@ -0,0 +1 @@ +Improve performance of ``TaskGroup.topological_sort`` by projecting each child's per-task upstream IDs onto sibling-level integer indices once up front and sweeping with a ``bytearray``-backed emission flag, moving per-edge work out of the outer loop. From 310aa0cde441b417f5c6987e9ddd2f1a0e74d388 Mon Sep 17 00:00:00 2001 From: Shahar Epstein <60007259+shahar1@users.noreply.github.com> Date: Thu, 21 May 2026 19:14:50 +0300 Subject: [PATCH 3/7] Add unit tests for correctness --- .../task_sdk/definitions/test_taskgroup.py | 144 ++++++++++++++++++ 1 file changed, 144 insertions(+) diff --git a/task-sdk/tests/task_sdk/definitions/test_taskgroup.py b/task-sdk/tests/task_sdk/definitions/test_taskgroup.py index 18c7f65faf2e4..d7655d4bad0d1 100644 --- a/task-sdk/tests/task_sdk/definitions/test_taskgroup.py +++ b/task-sdk/tests/task_sdk/definitions/test_taskgroup.py @@ -957,3 +957,147 @@ def test_getitem_missing_is_key_error(self): with pytest.raises(KeyError): tg["nonexistent"] + + +# --- topological_sort: cross-shape correctness --- +# +# Mirrors the shapes covered by the benchmark gist referenced from PR #67288 +# (https://gist.github.com/shahar1/9c61dc9f34f7e77cd29cfb9d67af7ceb). +# Wall-clock timing is intentionally not asserted here — CI runners are too +# variable for ms thresholds to be meaningful. Run dev/bench_topological_sort +# manually to gauge performance. + + +def _make_chain(n: int) -> DAG: + with DAG(f"chain_{n}", schedule=None, start_date=DEFAULT_DATE) as dag: + prev = None + for i in range(n): + t = EmptyOperator(task_id=f"t{i}") + if prev is not None: + prev >> t + prev = t + return dag + + +def _make_reverse_chain(n: int) -> DAG: + with DAG(f"reverse_chain_{n}", schedule=None, start_date=DEFAULT_DATE) as dag: + tasks = [EmptyOperator(task_id=f"t{n - 1 - i}") for i in range(n)] + by_id = {t.task_id: t for t in tasks} + for i in range(n - 1): + by_id[f"t{i}"] >> by_id[f"t{i + 1}"] + return dag + + +def _make_diamond(n: int) -> DAG: + with DAG(f"diamond_{n}", schedule=None, start_date=DEFAULT_DATE) as dag: + root = EmptyOperator(task_id="root") + sink = EmptyOperator(task_id="sink") + middles = [EmptyOperator(task_id=f"m{i}") for i in range(max(n - 2, 1))] + root >> middles >> sink + return dag + + +def _make_independent(n: int) -> DAG: + with DAG(f"independent_{n}", schedule=None, start_date=DEFAULT_DATE) as dag: + for i in range(n): + EmptyOperator(task_id=f"t{i}") + return dag + + +def _make_layered(n: int, layers: int = 4) -> DAG: + per_layer = max(n // layers, 1) + with DAG(f"layered_{n}", schedule=None, start_date=DEFAULT_DATE) as dag: + prev_layer: list = [] + for layer in range(layers): + cur = [EmptyOperator(task_id=f"L{layer}_t{i}") for i in range(per_layer)] + if prev_layer: + for upstream in prev_layer: + upstream >> cur + prev_layer = cur + return dag + + +def _make_nested_groups(n: int, depth: int = 3) -> DAG: + per_group = max(n // (depth * depth), 1) + with DAG(f"nested_{n}", schedule=None, start_date=DEFAULT_DATE) as dag: + + def build_group(level: int, idx: int) -> TaskGroup: + with TaskGroup(group_id=f"g{level}_{idx}") as tg: + prev = None + for i in range(per_group): + t = EmptyOperator(task_id=f"l{level}_g{idx}_t{i}") + if prev is not None: + prev >> t + prev = t + if level + 1 < depth: + inner_prev = None + for j in range(depth): + inner = build_group(level + 1, j) + if inner_prev is not None: + inner_prev >> inner + inner_prev = inner + return tg + + top_prev = None + for j in range(depth): + top = build_group(0, j) + if top_prev is not None: + top_prev >> top + top_prev = top + return dag + + +def _project_sibling(group: TaskGroup, upstream_task_id: str, child_id: str) -> str | None: + """Mirror of TaskGroup._project_child_deps' projection, returning a string ID.""" + children = group.children + if upstream_task_id in children: + return upstream_task_id if upstream_task_id != child_id else None + upstream = group.dag.get_task(upstream_task_id) + tg = upstream.task_group + while tg is not None: + if tg.node_id in children: + return tg.node_id if tg.node_id != child_id else None + tg = tg.parent_group + return None + + +def _walk_groups(tg: TaskGroup): + yield tg + for child in tg.children.values(): + if isinstance(child, TaskGroup): + yield from _walk_groups(child) + + +def _assert_valid_topological_order(group: TaskGroup, order: list[str]) -> None: + position = {node_id: i for i, node_id in enumerate(order)} + assert set(position) == set(group.children), ( + f"topological_sort output {order!r} does not cover children of {group.node_id!r}" + ) + for child_id, child in group.children.items(): + for upstream_id in child.upstream_task_ids: + sib = _project_sibling(group, upstream_id, child_id) + if sib is None: + continue + assert position[sib] < position[child_id], ( + f"In group {group.node_id!r}: sibling {sib!r} must precede {child_id!r}, got order {order!r}" + ) + + +@pytest.mark.parametrize( + ("shape", "builder"), + [ + ("chain", _make_chain), + ("rev-chain", _make_reverse_chain), + ("diamond", _make_diamond), + ("independent", _make_independent), + ("layered", _make_layered), + ("nested", _make_nested_groups), + ], +) +@pytest.mark.parametrize("n", [20, 100, 500]) +def test_topological_sort_shape_correctness(shape, builder, n): + """topological_sort emits a valid order for every nested group across DAG shapes.""" + dag = builder(n) + for group in _walk_groups(dag.task_group): + order = [node.node_id for node in group.topological_sort()] + _assert_valid_topological_order(group, order) From 9eba47c6fa4f540c970b2f289aa7c04941d1d2a1 Mon Sep 17 00:00:00 2001 From: Shahar Epstein <60007259+shahar1@users.noreply.github.com> Date: Thu, 21 May 2026 22:36:46 +0300 Subject: [PATCH 4/7] Address self-review: serialization-side test, user-facing newsfragment Adds a round-trip test for SerializedTaskGroup.topological_sort (the serialization variant was previously untested), rewrites the newsfragment in user-facing terms, and cleans up a stale reference and type annotation in the task-sdk shape tests. --- .../newsfragments/67288.improvement.rst | 2 +- .../tests/unit/utils/test_task_group.py | 28 +++++++++++++++++++ .../task_sdk/definitions/test_taskgroup.py | 4 +-- 3 files changed, 31 insertions(+), 3 deletions(-) diff --git a/airflow-core/newsfragments/67288.improvement.rst b/airflow-core/newsfragments/67288.improvement.rst index 5ca44781caf70..03293e4ffa240 100644 --- a/airflow-core/newsfragments/67288.improvement.rst +++ b/airflow-core/newsfragments/67288.improvement.rst @@ -1 +1 @@ -Improve performance of ``TaskGroup.topological_sort`` by projecting each child's per-task upstream IDs onto sibling-level integer indices once up front and sweeping with a ``bytearray``-backed emission flag, moving per-edge work out of the outer loop. +Speed up ``TaskGroup.topological_sort`` across Dag shapes (chain, diamond, layered, reverse-chain); benchmarks show roughly 2-8x faster on large groups. diff --git a/airflow-core/tests/unit/utils/test_task_group.py b/airflow-core/tests/unit/utils/test_task_group.py index 3b62ad75a7290..ffc217fc0789d 100644 --- a/airflow-core/tests/unit/utils/test_task_group.py +++ b/airflow-core/tests/unit/utils/test_task_group.py @@ -1117,6 +1117,34 @@ def nested_topo(group): ] +def test_topological_sort_serialized_layered(): + """SerializedTaskGroup.topological_sort emits a valid order after DAG round-trip. + + Exercises the projected-sweep path on the serialization variant (which is otherwise + untested), using a layered shape that forces multi-pass behavior. + """ + with DAG("test_topo_sort_serialized", schedule=None, start_date=DEFAULT_DATE) as dag: + layers: list[list[BaseOperator]] = [] + for layer_idx in range(4): + cur = [EmptyOperator(task_id=f"L{layer_idx}_t{i}") for i in range(3)] + if layers: + for upstream in layers[-1]: + upstream >> cur + layers.append(cur) + + serialized = create_scheduler_dag(dag) + order = [node.node_id for node in serialized.task_group.topological_sort()] + position = {nid: i for i, nid in enumerate(order)} + + assert set(position) == {t.task_id for layer in layers for t in layer} + for layer_idx in range(len(layers) - 1): + for upstream in layers[layer_idx]: + for downstream in layers[layer_idx + 1]: + assert position[upstream.task_id] < position[downstream.task_id], ( + f"{upstream.task_id!r} must precede {downstream.task_id!r}, got {order!r}" + ) + + def test_task_group_arrow_with_setup_group(): with DAG(dag_id="setup_group_teardown_group") as dag: with TaskGroup("group_1") as g1: diff --git a/task-sdk/tests/task_sdk/definitions/test_taskgroup.py b/task-sdk/tests/task_sdk/definitions/test_taskgroup.py index d7655d4bad0d1..6649aee8d0bd5 100644 --- a/task-sdk/tests/task_sdk/definitions/test_taskgroup.py +++ b/task-sdk/tests/task_sdk/definitions/test_taskgroup.py @@ -964,7 +964,7 @@ def test_getitem_missing_is_key_error(self): # Mirrors the shapes covered by the benchmark gist referenced from PR #67288 # (https://gist.github.com/shahar1/9c61dc9f34f7e77cd29cfb9d67af7ceb). # Wall-clock timing is intentionally not asserted here — CI runners are too -# variable for ms thresholds to be meaningful. Run dev/bench_topological_sort +# variable for ms thresholds to be meaningful. The gist above can be run # manually to gauge performance. @@ -1007,7 +1007,7 @@ def _make_independent(n: int) -> DAG: def _make_layered(n: int, layers: int = 4) -> DAG: per_layer = max(n // layers, 1) with DAG(f"layered_{n}", schedule=None, start_date=DEFAULT_DATE) as dag: - prev_layer: list = [] + prev_layer: list[EmptyOperator] = [] for layer in range(layers): cur = [EmptyOperator(task_id=f"L{layer}_t{i}") for i in range(per_layer)] if prev_layer: From 395443aef6738a72245dfa04c5be4090667492fc Mon Sep 17 00:00:00 2001 From: Shahar Epstein <60007259+shahar1@users.noreply.github.com> Date: Thu, 21 May 2026 22:39:36 +0300 Subject: [PATCH 5/7] Clarify topological_sort cycle-handling docstring MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Per Copilot review feedback, the previous wording said cycles are caught at "deserialization time" — but DAG.check_cycle runs at DAG parse time (via dagbag loading), not during from_dict/from_json. Reword to describe cycles reaching this code path as malformed serialized data, with the defensive ValueError still raised on detection. --- .../src/airflow/serialization/definitions/taskgroup.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/airflow-core/src/airflow/serialization/definitions/taskgroup.py b/airflow-core/src/airflow/serialization/definitions/taskgroup.py index 8d2362c6a6a0c..70f2fb51e675b 100644 --- a/airflow-core/src/airflow/serialization/definitions/taskgroup.py +++ b/airflow-core/src/airflow/serialization/definitions/taskgroup.py @@ -218,9 +218,10 @@ def topological_sort(self) -> list[DAGNode]: """ Sort children topologically — a task always comes after its upstream dependencies. - See ``TaskGroup.topological_sort`` in task-sdk for the algorithm. Cycle handling: - cycles are caught at deserialization time, so they should never reach this code; if - one does we raise ``ValueError`` rather than silently looping forever. + See ``TaskGroup.topological_sort`` in task-sdk for the algorithm. Cycles are + treated as corrupt input: ``DAG.check_cycle`` rejects cyclic Dags before + serialization, so a cycle reaching this code indicates malformed serialized data, + and we raise ``ValueError`` rather than silently looping forever. """ children = self.children if not children: From 411f90ec2b53b68bb940f9100dba74723992cd95 Mon Sep 17 00:00:00 2001 From: Shahar Epstein <60007259+shahar1@users.noreply.github.com> Date: Thu, 21 May 2026 23:19:57 +0300 Subject: [PATCH 6/7] Compact TaskGroup.topological_sort projection without losing perf MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Tighten _project_child_deps and the topological_sort body: drop the per-edge child_idx comparison in favour of a single set.discard at the end, inline the get_task/.task_group chain, and replace the pre-alloc loop with a list comprehension. Net -15 lines per file with a small (~10%) speedup on the layered shape and no regression elsewhere. The hot _sweep_projection inner loop is left intact — earlier attempts to extract its duplicated body into a closure cost 10-30% on independent / layered shapes and were reverted. --- .../serialization/definitions/taskgroup.py | 22 ++++++------------ .../src/airflow/sdk/definitions/taskgroup.py | 23 +++++++------------ 2 files changed, 15 insertions(+), 30 deletions(-) diff --git a/airflow-core/src/airflow/serialization/definitions/taskgroup.py b/airflow-core/src/airflow/serialization/definitions/taskgroup.py index 70f2fb51e675b..5db656019f1cb 100644 --- a/airflow-core/src/airflow/serialization/definitions/taskgroup.py +++ b/airflow-core/src/airflow/serialization/definitions/taskgroup.py @@ -226,15 +226,9 @@ def topological_sort(self) -> list[DAGNode]: children = self.children if not children: return [] - nodes = list(children.values()) - n = len(nodes) id_to_idx = {nid: i for i, nid in enumerate(children)} - - projected: list[tuple[int, ...]] = [()] * n - for i, child in enumerate(nodes): - projected[i] = self._project_child_deps(i, child, id_to_idx) - + projected = [self._project_child_deps(i, c, id_to_idx) for i, c in enumerate(nodes)] return self._sweep_projection(nodes, projected) def _project_child_deps( @@ -247,18 +241,16 @@ def _project_child_deps( for edge_id in upstream_ids: j = id_to_idx.get(edge_id) if j is not None: - if j != child_idx: - sib_deps.add(j) + sib_deps.add(j) continue - edge = self.dag.get_task(edge_id) - tg = edge.task_group + tg = self.dag.get_task(edge_id).task_group while tg is not None: - anc_idx = id_to_idx.get(tg.node_id) - if anc_idx is not None: - if anc_idx != child_idx: - sib_deps.add(anc_idx) + j = id_to_idx.get(tg.node_id) + if j is not None: + sib_deps.add(j) break tg = tg.parent_group + sib_deps.discard(child_idx) return tuple(sib_deps) def _sweep_projection(self, nodes: list[DAGNode], projected: list[tuple[int, ...]]) -> list[DAGNode]: diff --git a/task-sdk/src/airflow/sdk/definitions/taskgroup.py b/task-sdk/src/airflow/sdk/definitions/taskgroup.py index 0a491b64c3e28..67376cb817ae4 100644 --- a/task-sdk/src/airflow/sdk/definitions/taskgroup.py +++ b/task-sdk/src/airflow/sdk/definitions/taskgroup.py @@ -536,21 +536,16 @@ def topological_sort(self) -> list[DAGNode]: children = self.children if not children: return [] - nodes = list(children.values()) - n = len(nodes) id_to_idx = {nid: i for i, nid in enumerate(children)} - - projected: list[tuple[int, ...]] = [()] * n - for i, child in enumerate(nodes): - projected[i] = self._project_child_deps(i, child, id_to_idx) - + projected = [self._project_child_deps(i, c, id_to_idx) for i, c in enumerate(nodes)] return self._sweep_projection(nodes, projected) def _project_child_deps( self, child_idx: int, child: DAGNode, id_to_idx: dict[str, int] ) -> tuple[int, ...]: # Project one child's per-task upstream IDs onto sibling-level integer indices. + # Self-deps are filtered once at the end via ``discard`` so the inner loop stays tight. upstream_ids = child.upstream_task_ids if not upstream_ids: return () @@ -558,18 +553,16 @@ def _project_child_deps( for edge_id in upstream_ids: j = id_to_idx.get(edge_id) if j is not None: - if j != child_idx: - sib_deps.add(j) + sib_deps.add(j) continue - edge = self.dag.get_task(edge_id) - tg = edge.task_group + tg = self.dag.get_task(edge_id).task_group while tg is not None: - anc_idx = id_to_idx.get(tg.node_id) - if anc_idx is not None: - if anc_idx != child_idx: - sib_deps.add(anc_idx) + j = id_to_idx.get(tg.node_id) + if j is not None: + sib_deps.add(j) break tg = tg.parent_group + sib_deps.discard(child_idx) return tuple(sib_deps) def _sweep_projection(self, nodes: list[DAGNode], projected: list[tuple[int, ...]]) -> list[DAGNode]: From c065bbfcd60c948051cf76f72e862ece1d2f8989 Mon Sep 17 00:00:00 2001 From: Shahar Epstein <60007259+shahar1@users.noreply.github.com> Date: Thu, 21 May 2026 23:44:14 +0300 Subject: [PATCH 7/7] Trim topological_sort shape tests to n in {20, 100} Drop n=500 from the parametrize grid. The algorithm has no n-dependent branches, so n=100 covers every code path; n=500 only re-runs the same loops with more iterations and added ~0.7s to the file without exercising new correctness behaviour (per Copilot review feedback on PR #67288). --- task-sdk/tests/task_sdk/definitions/test_taskgroup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/task-sdk/tests/task_sdk/definitions/test_taskgroup.py b/task-sdk/tests/task_sdk/definitions/test_taskgroup.py index 6649aee8d0bd5..d1ba11e3056c9 100644 --- a/task-sdk/tests/task_sdk/definitions/test_taskgroup.py +++ b/task-sdk/tests/task_sdk/definitions/test_taskgroup.py @@ -1094,7 +1094,7 @@ def _assert_valid_topological_order(group: TaskGroup, order: list[str]) -> None: ("nested", _make_nested_groups), ], ) -@pytest.mark.parametrize("n", [20, 100, 500]) +@pytest.mark.parametrize("n", [20, 100]) def test_topological_sort_shape_correctness(shape, builder, n): """topological_sort emits a valid order for every nested group across DAG shapes.""" dag = builder(n)