Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions airflow-core/newsfragments/67288.improvement.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Speed up ``TaskGroup.topological_sort`` across Dag shapes (chain, diamond, layered, reverse-chain); benchmarks show roughly 2-8x faster on large groups.
Comment thread
shahar1 marked this conversation as resolved.
97 changes: 70 additions & 27 deletions airflow-core/src/airflow/serialization/definitions/taskgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@

from __future__ import annotations

import copy
import functools
import operator
import weakref
Expand Down Expand Up @@ -217,35 +216,79 @@ def iter_mapped_task_groups(self) -> Iterator[SerializedMappedTaskGroup]:

def topological_sort(self) -> list[DAGNode]:
    """
    Sort children topologically — a task always comes after its upstream dependencies.

    See ``TaskGroup.topological_sort`` in task-sdk for the algorithm. Cycles are
    treated as corrupt input: ``DAG.check_cycle`` rejects cyclic Dags before
    serialization, so a cycle reaching this code indicates malformed serialized data,
    and we raise ``ValueError`` rather than silently looping forever.

    :return: the direct children of this group, each placed after the siblings it depends on
    :raises ValueError: if the sibling-level dependencies contain a cycle
    """
    children = self.children
    if not children:
        return []
    nodes = list(children.values())
    # Index of every direct child in ``nodes``, keyed by the same key ``children`` uses.
    id_to_idx = {nid: i for i, nid in enumerate(children)}
    # Resolve each child's upstream ids to sibling indices once, before sweeping.
    projected = [self._project_child_deps(i, c, id_to_idx) for i, c in enumerate(nodes)]
    return self._sweep_projection(nodes, projected)

def _project_child_deps(
    self, child_idx: int, child: DAGNode, id_to_idx: dict[str, int]
) -> tuple[int, ...]:
    """
    Map one child's upstream task ids onto the indices of the siblings it depends on.

    An upstream id that is itself a key of ``id_to_idx`` maps directly. Otherwise the
    upstream task is looked up in ``self.dag`` and its task-group chain is walked outwards
    until a group whose ``node_id`` is in ``id_to_idx`` is found; an upstream with no such
    ancestor contributes nothing.

    :param child_idx: index of ``child`` itself; it is removed from the result
    :param child: node whose ``upstream_task_ids`` are projected
    :param id_to_idx: mapping from node id to index for the direct children being sorted
    :return: indices of the siblings ``child`` depends on, in unspecified order
    """
    upstream_ids = child.upstream_task_ids
    if not upstream_ids:
        return ()
    sib_deps: set[int] = set()
    for edge_id in upstream_ids:
        j = id_to_idx.get(edge_id)
        if j is None:
            # Not a direct sibling: climb from the upstream task's own group towards
            # the root until we reach a group that is one of the siblings.
            tg = self.dag.get_task(edge_id).task_group
            while tg is not None:
                j = id_to_idx.get(tg.node_id)
                if j is not None:
                    break
                tg = tg.parent_group
        if j is not None:
            sib_deps.add(j)
    # Self-dependencies are dropped once here instead of being tested on every edge.
    sib_deps.discard(child_idx)
    return tuple(sib_deps)

def _sweep_projection(self, nodes: list[DAGNode], projected: list[tuple[int, ...]]) -> list[DAGNode]:
    """
    Emit ``nodes`` in dependency order by sweeping repeatedly over the not-yet-emitted indices.

    A node is emitted as soon as every index in its ``projected`` entry has been emitted;
    nodes emitted earlier in a sweep unblock nodes visited later in the same sweep.

    :param nodes: children to order; ``nodes[i]`` is described by ``projected[i]``
    :param projected: for each node, the indices into ``nodes`` that must be emitted first
    :return: ``nodes`` reordered so that every node follows all of its projected dependencies
    :raises ValueError: if a whole sweep emits nothing, i.e. the remaining nodes form a cycle
    """
    emitted = bytearray(len(nodes))  # emitted[i] == 1 once nodes[i] has been appended to ``order``
    order: list[DAGNode] = []
    emit = order.append
    # The first sweep visits every index. ``range`` is lazy and sized, so seeding with it
    # allocates nothing and lets a single loop serve the first and all follow-up sweeps.
    pending: range | list[int] = range(len(nodes))
    while pending:
        still_blocked: list[int] = []
        for i in pending:
            for dep in projected[i]:
                if not emitted[dep]:
                    still_blocked.append(i)
                    break
            else:
                emitted[i] = 1
                emit(nodes[i])
        # No progress in a full sweep means the remaining nodes wait on each other.
        if len(still_blocked) == len(pending):
            raise ValueError(f"A cyclic dependency occurred in dag: {self.dag_id}")
        pending = still_blocked
    return order

def add(self, node: DAGNode) -> DAGNode:
# Set the TG first, as setting it might change the return value of node_id!
Expand Down
28 changes: 28 additions & 0 deletions airflow-core/tests/unit/utils/test_task_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -1117,6 +1117,34 @@ def nested_topo(group):
]


def test_topological_sort_serialized_layered():
    """SerializedTaskGroup.topological_sort emits a valid order after DAG round-trip.

    Exercises the projected-sweep path on the serialization variant (which is otherwise
    untested), using a layered shape. The layers are created and named sink-first, so
    that whether the group's children are visited in creation order or in id order, the
    downstream layers are reached before their upstreams and the sort cannot finish in
    a single sweep.
    """
    with DAG("test_topo_sort_serialized", schedule=None, start_date=DEFAULT_DATE) as dag:
        # layers[0] is the final (sink) layer; layers[k + 1] is upstream of layers[k].
        layers: list[list[BaseOperator]] = []
        for layer_idx in range(4):
            cur = [EmptyOperator(task_id=f"L{layer_idx}_t{i}") for i in range(3)]
            if layers:
                for downstream in layers[-1]:
                    downstream.set_upstream(cur)
            layers.append(cur)

    serialized = create_scheduler_dag(dag)
    order = [node.node_id for node in serialized.task_group.topological_sort()]
    position = {nid: i for i, nid in enumerate(order)}

    assert set(position) == {t.task_id for layer in layers for t in layer}
    for layer_idx in range(len(layers) - 1):
        for upstream in layers[layer_idx + 1]:
            for downstream in layers[layer_idx]:
                assert position[upstream.task_id] < position[downstream.task_id], (
                    f"{upstream.task_id!r} must precede {downstream.task_id!r}, got {order!r}"
                )


def test_task_group_arrow_with_setup_group():
with DAG(dag_id="setup_group_teardown_group") as dag:
with TaskGroup("group_1") as g1:
Expand Down
125 changes: 79 additions & 46 deletions task-sdk/src/airflow/sdk/definitions/taskgroup.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,57 +523,90 @@ def hierarchical_alphabetical_sort(self):
key=lambda node: (not isinstance(node, TaskGroup), node.node_id),
)

def topological_sort(self) -> list[DAGNode]:
    """
    Sort children topologically — a task always comes after its upstream dependencies.

    Projects each child's per-task upstream IDs onto sibling-level integer indices once,
    then runs a greedy multi-pass sweep using a bytearray-backed emission flag. Equivalent
    in emission order to the previous modified-Kahn implementation, but moves the per-edge
    ``upstream_list`` materialization and ``parent_group`` walks out of the sweep's inner
    loop so they happen once per call instead of once per outer-loop pass.

    :return: list of this group's direct children in topological order
    :raises AirflowDagCycleException: if the sibling-level dependencies contain a cycle
    """
    children = self.children
    if not children:
        return []
    nodes = list(children.values())
    # Index of every direct child in ``nodes``, keyed by the same key ``children`` uses.
    id_to_idx = {nid: i for i, nid in enumerate(children)}
    # Resolve each child's upstream ids to sibling indices once, before sweeping.
    projected = [self._project_child_deps(i, c, id_to_idx) for i, c in enumerate(nodes)]
    return self._sweep_projection(nodes, projected)

def _project_child_deps(
    self, child_idx: int, child: DAGNode, id_to_idx: dict[str, int]
) -> tuple[int, ...]:
    """
    Project one child's per-task upstream IDs onto sibling-level integer indices.

    An upstream id that is itself a key of ``id_to_idx`` maps directly. Otherwise the
    upstream task is looked up in ``self.dag`` and its task-group chain is walked outwards
    until a group whose ``node_id`` is in ``id_to_idx`` is found; an upstream with no such
    ancestor contributes nothing.

    :param child_idx: index of ``child`` itself; it is removed from the result
    :param child: node whose ``upstream_task_ids`` are projected
    :param id_to_idx: mapping from node id to index for the direct children being sorted
    :return: indices of the siblings ``child`` depends on, in unspecified order
    """
    upstream_ids = child.upstream_task_ids
    if not upstream_ids:
        return ()
    sib_deps: set[int] = set()
    for edge_id in upstream_ids:
        j = id_to_idx.get(edge_id)
        if j is None:
            # Not a direct sibling: climb from the upstream task's own group towards
            # the root until we reach a group that is one of the siblings.
            tg = self.dag.get_task(edge_id).task_group
            while tg is not None:
                j = id_to_idx.get(tg.node_id)
                if j is not None:
                    break
                tg = tg.parent_group
        if j is not None:
            sib_deps.add(j)
    # Self-deps are filtered once at the end via ``discard`` so the inner loop stays tight.
    sib_deps.discard(child_idx)
    return tuple(sib_deps)

def _sweep_projection(self, nodes: list[DAGNode], projected: list[tuple[int, ...]]) -> list[DAGNode]:
    """
    Greedy multi-pass sweep that emits ``nodes`` in dependency order.

    A node is emitted as soon as every index in its ``projected`` entry has been emitted;
    nodes emitted earlier in a sweep unblock nodes visited later in the same sweep. Only
    nodes that were blocked are re-checked in the next sweep, so already-emitted nodes are
    skipped on multi-pass shapes (e.g. a diamond's single trailing sink).

    :param nodes: children to order; ``nodes[i]`` is described by ``projected[i]``
    :param projected: for each node, the indices into ``nodes`` that must be emitted first
    :return: ``nodes`` reordered so that every node follows all of its projected dependencies
    :raises AirflowDagCycleException: if a whole sweep emits nothing, i.e. the remaining
        nodes form a cycle
    """
    emitted = bytearray(len(nodes))  # emitted[i] == 1 iff nodes[i] has been emitted
    order: list[DAGNode] = []
    emit = order.append
    # The first sweep visits every index. ``range`` is lazy and sized, so seeding with it
    # avoids a ``list(range(n))`` allocation on single-pass shapes while letting a single
    # loop serve the first and all follow-up sweeps.
    pending: range | list[int] = range(len(nodes))
    while pending:
        still_blocked: list[int] = []
        for i in pending:
            for dep in projected[i]:
                if not emitted[dep]:
                    still_blocked.append(i)
                    break
            else:
                emitted[i] = 1
                emit(nodes[i])
        # No progress in a full sweep means the remaining nodes wait on each other.
        if len(still_blocked) == len(pending):
            raise AirflowDagCycleException(f"A cyclic dependency occurred in dag: {self.dag_id}")
        pending = still_blocked
    return order

def iter_mapped_task_groups(self) -> Iterator[MappedTaskGroup]:
"""
Expand Down
Loading
Loading