Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -611,7 +611,7 @@ def handle_bulk_delete(
try:
# Handle deletion of specific (dag_id, dag_run_id, task_id, map_index) tuples
if delete_specific_map_index_task_keys:
_, matched_task_keys, not_found_task_keys = self._categorize_task_instances(
task_instances_map, matched_task_keys, not_found_task_keys = self._categorize_task_instances(
delete_specific_map_index_task_keys
)
not_found_task_ids = [
Expand All @@ -625,23 +625,10 @@ def handle_bulk_delete(
detail=f"The task instances with these identifiers: {not_found_task_ids} were not found",
)

for dag_id, run_id, task_id, map_index in matched_task_keys:
ti = (
self.session.execute(
select(TI).where(
TI.dag_id == dag_id,
TI.run_id == run_id,
TI.task_id == task_id,
TI.map_index == map_index,
)
)
.scalars()
.one_or_none()
)

if ti:
self.session.delete(ti)
results.success.append(f"{dag_id}.{run_id}.{task_id}[{map_index}]")
for task_key in matched_task_keys:
dag_id, run_id, task_id, map_index = task_key
self.session.delete(task_instances_map[task_key])
results.success.append(f"{dag_id}.{run_id}.{task_id}[{map_index}]")

# Handle deletion of all map indexes for certain (dag_id, dag_run_id, task_id) tuples
if delete_all_map_index_task_keys:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6719,6 +6719,46 @@ def test_bulk_delete_rejects_unauthorized_dag_ids_from_request_body(self, test_c
}
]

@pytest.mark.parametrize("task_count", [5, 10, 20])
def test_bulk_delete_query_count_scales_linearly_with_task_count(self, test_client, session, task_count):
# Regression guard for the N+1 fix in BulkTaskInstanceService.handle_bulk_delete:
# each extra task instance must add exactly QUERIES_PER_TASK_INSTANCE query (its DELETE),
# not 2 (DELETE + re-SELECT). A regression that re-queries inside the loop would make
# each run strictly exceed BASE_QUERY_COUNT + task_count * QUERIES_PER_TASK_INSTANCE.
QUERIES_PER_TASK_INSTANCE = 1
BASE_QUERY_COUNT = 5

self.create_task_instances(
session,
task_instances=[{"state": State.RUNNING, "map_indexes": tuple(range(task_count))}],
)
request_body = {
"actions": [
{
"action": "delete",
"entities": [
{"task_id": self.TASK_ID, "map_index": map_index} for map_index in range(task_count)
],
"action_on_non_existence": "fail",
}
]
}

with count_queries() as result:
response = test_client.patch(self.ENDPOINT_URL, json=request_body)

assert response.status_code == 200
assert len(response.json()["delete"]["success"]) == task_count

query_count = sum(result.values())
expected_query_count = BASE_QUERY_COUNT + task_count * QUERIES_PER_TASK_INSTANCE
assert query_count == expected_query_count, (
f"Bulk-delete query count {query_count} does not match expected {expected_query_count} "
f"for {task_count} task instances. "
f"A regression that re-queries each task instance would give "
f"~{BASE_QUERY_COUNT + task_count * 2} queries instead."
)

def test_should_respond_401(self, unauthenticated_test_client):
response = unauthenticated_test_client.patch(self.ENDPOINT_URL, json={})
assert response.status_code == 401
Expand Down