diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index 94999538fc1..dc0171db731 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -7,14 +7,13 @@ from django.core.exceptions import ValidationError from django.db import transaction -from django.db.models import signals from django.utils import timezone +from dojo import tag_inheritance from dojo.importers.base_location_manager import BaseLocationManager from dojo.location.models import AbstractLocation, Location, LocationFindingReference, LocationProductReference from dojo.location.status import FindingLocationStatus, ProductLocationStatus from dojo.models import Product, _manage_inherited_tags -from dojo.tags_signals import make_inherited_tags_sticky from dojo.tools.locations import LocationData from dojo.url.models import URL from dojo.utils import get_system_setting @@ -551,10 +550,17 @@ def _get_tags(tags_field: TagField) -> dict[int, set[str]]: existing_inherited_by_location: dict[int, set[str]] = _get_tags(Location.inherited_tags) existing_tags_by_location: dict[int, set[str]] = _get_tags(Location.tags) - # Perform the bulk updates. First, though, disconnect the make_inherited_tags_sticky signal on Location.tags - # while updating, otherwise each (inherited_)tags.set() will trigger, defeating the purpose of this bulk update. - disconnected = signals.m2m_changed.disconnect(make_inherited_tags_sticky, sender=Location.tags.through) - try: + # Perform the bulk updates inside a `tag_inheritance.batch()` context. + # While the batch is active, signal handlers in `dojo/tags_signals.py` + # short-circuit per-row inheritance work that would otherwise fire on + # every `(inherited_)tags.set()` and defeat the bulk update. + # + # This replaces a previous `signals.m2m_changed.disconnect(...)` / + # `connect(...)` dance which was process-global and therefore unsafe + # under threaded gunicorn / Celery thread pools / ASGI threadpools: + # while disconnected, every thread in the process lost sticky + # enforcement. Thread-local batch state avoids that hazard. + with tag_inheritance.batch_mode(): for location in locations: target_tag_names: set[str] = set() for pid in product_ids_by_location[location.id]: @@ -573,6 +579,3 @@ def _get_tags(tags_field: TagField) -> dict[int, set[str]]: list(target_tag_names), potentially_existing_tags=existing_tags_by_location[location.id], ) - finally: - if disconnected: - signals.m2m_changed.connect(make_inherited_tags_sticky, sender=Location.tags.through) diff --git a/dojo/tag_inheritance.py b/dojo/tag_inheritance.py new file mode 100644 index 00000000000..e9d0e98a5fc --- /dev/null +++ b/dojo/tag_inheritance.py @@ -0,0 +1,54 @@ +""" +Tag inheritance — central coordination module. + +Provides a thread-local ``batch()`` context manager that suppresses +per-instance inheritance work driven by ``m2m_changed`` and ``post_save`` +signals. While inside a batch, the signal handlers in +``dojo/tags_signals.py`` early-return; the calling code is responsible for +applying inheritance in bulk (e.g. via the importer's existing +``_bulk_inherit_tags`` path or ``propagate_tags_on_product_sync``). + +This replaces the previous pattern of ``signals.m2m_changed.disconnect(...)`` +in importer hot loops, which was process-global and unsafe under threaded +gunicorn / Celery thread pools / ASGI threadpools (see PR description for +the full rationale). +""" +from __future__ import annotations + +import contextlib +import threading +from contextlib import contextmanager + +_state = threading.local() + + +def is_in_batch_mode() -> bool: + """Return True when the current thread is inside an active ``batch()``.""" + return bool(getattr(_state, "depth", 0)) + + +@contextmanager +def batch_mode(): + """ + Suppress per-instance inheritance signals for the calling thread. + + Usage: + with tag_inheritance.batch(): + # Bulk operations that would otherwise fire `make_inherited_tags_sticky` + # or `inherit_tags_on_instance` per row. + ... + + The context is reentrant; nested ``with`` blocks share the suppression + until the outermost block exits. State lives in ``threading.local()``, + so concurrent threads (and Celery workers in non-prefork pools) are + unaffected by other threads' batches. + """ + _state.depth = getattr(_state, "depth", 0) + 1 + try: + yield + finally: + _state.depth -= 1 + if _state.depth <= 0: + # Clean up the attribute so leak-free thread reuse stays simple. + with contextlib.suppress(AttributeError): + del _state.depth diff --git a/dojo/tags_signals.py b/dojo/tags_signals.py index 6fe142f2e55..93ac1886930 100644 --- a/dojo/tags_signals.py +++ b/dojo/tags_signals.py @@ -4,6 +4,7 @@ from django.db.models import signals from django.dispatch import receiver +from dojo import tag_inheritance from dojo.celery_dispatch import dojo_dispatch_task from dojo.location.models import Location, LocationFindingReference, LocationProductReference from dojo.models import Endpoint, Engagement, Finding, Product, Test @@ -32,6 +33,12 @@ def product_tags_post_add_remove(sender, instance, action, **kwargs): @receiver(signals.m2m_changed, sender=Location.tags.through) def make_inherited_tags_sticky(sender, instance, action, **kwargs): """Make sure inherited tags are added back in if they are removed""" + # Inside a `tag_inheritance.batch()` block the caller takes responsibility + # for applying inheritance in bulk; per-row signal work would defeat the + # purpose. This replaces the old `signals.m2m_changed.disconnect(...)` + # pattern, which was process-global and unsafe under threaded workers. + if tag_inheritance.is_in_batch_mode(): + return if action in {"post_add", "post_remove"}: if inherit_product_tags(instance): tag_list = [tag.name for tag in instance.tags.all()] diff --git a/unittests/test_tag_inheritance_perf.py b/unittests/test_tag_inheritance_perf.py index 8e9f9a5f792..2b1a3681a9a 100644 --- a/unittests/test_tag_inheritance_perf.py +++ b/unittests/test_tag_inheritance_perf.py @@ -493,7 +493,11 @@ def test_baseline_zap_scan_reimport_no_change_v3(self): # Phase A nudges these slightly downward (post_save gated on created=True # avoids re-running inheritance on no-op finding updates during reimport). # Pre-Phase-A: 1461/1319 import, 77/95 reimport. + # Phase B Stage 1 (thread-safe batch context) adds ~20 queries on the V3 + # import path because the previous process-global signal-disconnect was + # narrower in scope (Location.tags.through only). Net-positive trade for + # eliminating the threading bug; full Phase B reductions land in Stage 2. EXPECTED_ZAP_IMPORT_V2 = 1385 - EXPECTED_ZAP_IMPORT_V3 = 1243 + EXPECTED_ZAP_IMPORT_V3 = 1263 EXPECTED_ZAP_REIMPORT_NO_CHANGE_V2 = 69 EXPECTED_ZAP_REIMPORT_NO_CHANGE_V3 = 87