diff --git a/dojo/importers/default_importer.py b/dojo/importers/default_importer.py index 25d41c2b5cb..6ae1a88aa29 100644 --- a/dojo/importers/default_importer.py +++ b/dojo/importers/default_importer.py @@ -5,6 +5,7 @@ from django.db.models.query_utils import Q from django.urls import reverse +from dojo import tag_inheritance from dojo.celery_dispatch import dojo_dispatch_task from dojo.finding import helper as finding_helper from dojo.importers.base_importer import BaseImporter, Parser @@ -114,29 +115,36 @@ def process_scan( # Note: for fresh imports, parse_findings() calls create_test() internally, # so self.test is guaranteed to be set after this call. parsed_findings = self.parse_findings(scan, parser) or [] - new_findings = self.process_findings(parsed_findings, **kwargs) - # Close any old findings in the processed list if the the user specified for that - # to occur in the form that is then passed to the kwargs - closed_findings = self.close_old_findings(self.test.finding_set.all(), **kwargs) - # Update the timestamps of the test object by looking at the findings imported - self.update_timestamps() - # Update the test meta - self.update_test_meta() - # Save the test and engagement for changes to take affect - self.test.save() - self.test.engagement.save() - # Create a test import history object to record the flags sent to the importer - # This operation will return None if the user does not have the import history - # feature enabled - test_import_history = self.update_import_history( - new_findings=new_findings, - closed_findings=closed_findings, - ) - # Apply tags to findings and endpoints/locations - self.apply_import_tags( - new_findings=new_findings, - closed_findings=closed_findings, - ) + # Open a tag-inheritance context. Signal handlers register touched + # instances into the context; the context auto-flushes (bulk-applies + # inherited tags) on exit. 
Mid-context `ctx.flush()` calls drain the + # accumulated set early — used before per-batch post-process dispatch + # so JIRA labels reflect the full tag state on first push. + with tag_inheritance.batch() as tag_ctx: + new_findings = self.process_findings(parsed_findings, **kwargs) + # Close any old findings in the processed list if the user specified for that + # to occur in the form that is then passed to the kwargs + closed_findings = self.close_old_findings(self.test.finding_set.all(), **kwargs) + # Update the timestamps of the test object by looking at the findings imported + self.update_timestamps() + # Update the test meta + self.update_test_meta() + # Save the test and engagement for changes to take effect + self.test.save() + self.test.engagement.save() + # Create a test import history object to record the flags sent to the importer + # This operation will return None if the user does not have the import history + # feature enabled + test_import_history = self.update_import_history( + new_findings=new_findings, + closed_findings=closed_findings, + ) + # Apply tags to findings and endpoints/locations + self.apply_import_tags( + new_findings=new_findings, + closed_findings=closed_findings, + ) + # Inheritance auto-flushes on context exit above. # Send out some notifications to the user logger.debug("IMPORT_SCAN: Generating notifications") dojo_dispatch_task( @@ -269,6 +277,11 @@ def process_findings( findings_with_parser_tags.clear() finding_ids_batch = list(batch_finding_ids) batch_finding_ids.clear() + # Drain the inheritance context BEFORE dispatching post-process + # so the JIRA push inside that task sees inherited tags on the + # findings (otherwise inheritance lands later, on context exit). 
+ if (ctx := tag_inheritance.current()) is not None: + ctx.flush() logger.debug("process_findings: dispatching batch with push_to_jira=%s (batch_size=%d, is_final=%s)", push_to_jira, len(finding_ids_batch), is_final_finding) dojo_dispatch_task( diff --git a/dojo/importers/default_reimporter.py b/dojo/importers/default_reimporter.py index 06f06ca368f..083ae871365 100644 --- a/dojo/importers/default_reimporter.py +++ b/dojo/importers/default_reimporter.py @@ -6,6 +6,7 @@ from django.db.models.query_utils import Q import dojo.finding.helper as finding_helper +from dojo import tag_inheritance from dojo.celery_dispatch import dojo_dispatch_task from dojo.finding.deduplication import ( find_candidates_for_deduplication_hash, @@ -107,43 +108,48 @@ def process_scan( # Get the findings from the parser based on what methods the parser supplies # This could either mean traditional file parsing, or API pull parsing parsed_findings = self.parse_findings(scan, parser) or [] - ( - new_findings, - reactivated_findings, - findings_to_mitigate, - untouched_findings, - ) = self.process_findings(parsed_findings, **kwargs) - # Close any old findings in the processed list if the the user specified for that - # to occur in the form that is then passed to the kwargs - closed_findings = self.close_old_findings(findings_to_mitigate, **kwargs) - # Update the timestamps of the test object by looking at the findings imported - logger.debug("REIMPORT_SCAN: Updating test/engagement timestamps") - # Update the timestamps of the test object by looking at the findings imported - self.update_timestamps() - # Update the test meta - self.update_test_meta() - # Update the test tags - self.update_test_tags() - # Save the test and engagement for changes to take affect - self.test.save() - self.test.engagement.save() - logger.debug("REIMPORT_SCAN: Updating test tags") - # Create a test import history object to record the flags sent to the importer - # This operation will return None if the user does not 
have the import history - # feature enabled - test_import_history = self.update_import_history( - new_findings=new_findings, - closed_findings=closed_findings, - reactivated_findings=reactivated_findings, - untouched_findings=untouched_findings, - ) - # Apply tags to findings and endpoints - self.apply_import_tags( - new_findings=new_findings, - closed_findings=closed_findings, - reactivated_findings=reactivated_findings, - untouched_findings=untouched_findings, - ) + # Open a tag-inheritance context (auto-flushes on exit). Signals + # register touched instances; `ctx.flush()` mid-loop drains them + # before each post-process dispatch so JIRA labels are correct. + with tag_inheritance.batch() as tag_ctx: + ( + new_findings, + reactivated_findings, + findings_to_mitigate, + untouched_findings, + ) = self.process_findings(parsed_findings, **kwargs) + # Close any old findings in the processed list if the user specified for that + # to occur in the form that is then passed to the kwargs + closed_findings = self.close_old_findings(findings_to_mitigate, **kwargs) + # Update the timestamps of the test object by looking at the findings imported + logger.debug("REIMPORT_SCAN: Updating test/engagement timestamps") + # Update the timestamps of the test object by looking at the findings imported + self.update_timestamps() + # Update the test meta + self.update_test_meta() + # Update the test tags + self.update_test_tags() + # Save the test and engagement for changes to take effect + self.test.save() + self.test.engagement.save() + logger.debug("REIMPORT_SCAN: Updating test tags") + # Create a test import history object to record the flags sent to the importer + # This operation will return None if the user does not have the import history + # feature enabled + test_import_history = self.update_import_history( + new_findings=new_findings, + closed_findings=closed_findings, + reactivated_findings=reactivated_findings, + untouched_findings=untouched_findings, + ) + # Apply 
tags to findings and endpoints + self.apply_import_tags( + new_findings=new_findings, + closed_findings=closed_findings, + reactivated_findings=reactivated_findings, + untouched_findings=untouched_findings, + ) + # Inheritance auto-flushes on context exit above. # Send out som notifications to the user logger.debug("REIMPORT_SCAN: Generating notifications") updated_count = ( @@ -427,6 +433,11 @@ def process_findings( findings_with_parser_tags.clear() finding_ids_batch = list(batch_finding_ids) batch_finding_ids.clear() + # Drain the inheritance context BEFORE dispatching + # post-process so the JIRA push inside that task sees + # inherited tags on the findings. + if (ctx := tag_inheritance.current()) is not None: + ctx.flush() dojo_dispatch_task( finding_helper.post_process_findings_batch, finding_ids_batch, diff --git a/dojo/importers/location_manager.py b/dojo/importers/location_manager.py index 94999538fc1..935d5159427 100644 --- a/dojo/importers/location_manager.py +++ b/dojo/importers/location_manager.py @@ -7,14 +7,13 @@ from django.core.exceptions import ValidationError from django.db import transaction -from django.db.models import signals from django.utils import timezone +from dojo import tag_inheritance from dojo.importers.base_location_manager import BaseLocationManager from dojo.location.models import AbstractLocation, Location, LocationFindingReference, LocationProductReference from dojo.location.status import FindingLocationStatus, ProductLocationStatus from dojo.models import Product, _manage_inherited_tags -from dojo.tags_signals import make_inherited_tags_sticky from dojo.tools.locations import LocationData from dojo.url.models import URL from dojo.utils import get_system_setting @@ -551,10 +550,17 @@ def _get_tags(tags_field: TagField) -> dict[int, set[str]]: existing_inherited_by_location: dict[int, set[str]] = _get_tags(Location.inherited_tags) existing_tags_by_location: dict[int, set[str]] = _get_tags(Location.tags) - # Perform the bulk 
updates. First, though, disconnect the make_inherited_tags_sticky signal on Location.tags - # while updating, otherwise each (inherited_)tags.set() will trigger, defeating the purpose of this bulk update. - disconnected = signals.m2m_changed.disconnect(make_inherited_tags_sticky, sender=Location.tags.through) - try: + # Perform the bulk updates inside a `tag_inheritance.batch()` context. + # While the batch is active, signal handlers in `dojo/tags_signals.py` + # short-circuit per-row inheritance work that would otherwise fire on + # every `(inherited_)tags.set()` and defeat the bulk update. + # + # This replaces a previous `signals.m2m_changed.disconnect(...)` / + # `connect(...)` dance which was process-global and therefore unsafe + # under threaded gunicorn / Celery thread pools / ASGI threadpools: + # while disconnected, every thread in the process lost sticky + # enforcement. Thread-local batch state avoids that hazard. + with tag_inheritance.batch(): for location in locations: target_tag_names: set[str] = set() for pid in product_ids_by_location[location.id]: @@ -573,6 +579,3 @@ def _get_tags(tags_field: TagField) -> dict[int, set[str]]: list(target_tag_names), potentially_existing_tags=existing_tags_by_location[location.id], ) - finally: - if disconnected: - signals.m2m_changed.connect(make_inherited_tags_sticky, sender=Location.tags.through) diff --git a/dojo/product/helpers.py b/dojo/product/helpers.py index cdb0750f317..331ebc534b6 100644 --- a/dojo/product/helpers.py +++ b/dojo/product/helpers.py @@ -1,12 +1,14 @@ import contextlib import logging +from collections import defaultdict from django.conf import settings from django.db.models import Q from dojo.celery import app -from dojo.location.models import Location +from dojo.location.models import Location, LocationFindingReference, LocationProductReference from dojo.models import Endpoint, Engagement, Finding, Product, Test +from dojo.tag_utils import bulk_add_tag_mapping, 
bulk_remove_tags_from_instances logger = logging.getLogger(__name__) @@ -19,35 +21,194 @@ def propagate_tags_on_product(product_id, *args, **kwargs): def propagate_tags_on_product_sync(product): - # enagagements + """ + Bulk-apply Product tag changes to all children using through-table SQL. + + Replaces the previous per-row `.save()` loop. For every child model owned + by the product (Engagement, Test, Finding, plus Endpoint or Location + depending on the V3_FEATURE_LOCATIONS flag), reads the existing + `inherited_tags` per child in one query, computes the diff against the + Product's current tags, and applies adds/removes via the bulk tag + helpers. Both `tags` and `inherited_tags` fields are kept in sync. + """ + target_names = {tag.name for tag in product.tags.all()} + logger.debug("Propagating tags from %s to all engagements", product) - propagate_tags_on_object_list(Engagement.objects.filter(product=product)) - # tests + _sync_inheritance_for_qs( + Engagement.objects.filter(product=product), + target_names_per_child=lambda _child: target_names, + ) logger.debug("Propagating tags from %s to all tests", product) - propagate_tags_on_object_list(Test.objects.filter(engagement__product=product)) - # findings + _sync_inheritance_for_qs( + Test.objects.filter(engagement__product=product), + target_names_per_child=lambda _child: target_names, + ) logger.debug("Propagating tags from %s to all findings", product) - propagate_tags_on_object_list(Finding.objects.filter(test__engagement__product=product)) + _sync_inheritance_for_qs( + Finding.objects.filter(test__engagement__product=product), + target_names_per_child=lambda _child: target_names, + ) if settings.V3_FEATURE_LOCATIONS: - # Locations logger.debug("Propagating tags from %s to all locations", product) - propagate_tags_on_object_list( - Location.objects.filter( - # Locations linked directly to a product via LocationProductReference - Q(products__product=product) - # Locations linked indirectly to a product via 
LocationFindingReference - | Q(findings__finding__test__engagement__product=product), - ).distinct(), + # Materialize once so we can build a precomputed + # {location_id: set[tag_name]} map without re-evaluating the queryset + # or paying N+1 in `_location_target_names`. + locations = list(Location.objects.filter( + Q(products__product=product) + | Q(findings__finding__test__engagement__product=product), + ).distinct()) + location_target_names = _build_location_target_names_map( + [loc.pk for loc in locations], + ) + _sync_inheritance_for_qs( + locations, + target_names_per_child=lambda loc: location_target_names.get(loc.pk, set()), ) else: - # TODO: Delete this after the move to Locations - # endpoints logger.debug("Propagating tags from %s to all endpoints", product) - propagate_tags_on_object_list(Endpoint.objects.filter(product=product)) + _sync_inheritance_for_qs( + Endpoint.objects.filter(product=product), + target_names_per_child=lambda _child: target_names, + ) + + +def _build_location_target_names_map(location_ids): + """ + Bulk-compute {location_id: set[tag_name]} for the given locations. + + Replaces the per-location `_location_target_names` callable, which issued + one `Product.objects.filter(...).distinct()` query plus N `.tags.all()` + queries per location. Now: 3 queries total regardless of fan-out. 
+ """ + if not location_ids: + return {} + + location_to_products: dict[int, set[int]] = defaultdict(set) + for loc_id, prod_id in LocationProductReference.objects.filter( + location_id__in=location_ids, + ).values_list("location_id", "product_id"): + location_to_products[loc_id].add(prod_id) + for loc_id, prod_id in LocationFindingReference.objects.filter( + location_id__in=location_ids, + ).values_list("location_id", "finding__test__engagement__product_id"): + if prod_id is not None: + location_to_products[loc_id].add(prod_id) + + all_product_ids = {pid for pids in location_to_products.values() for pid in pids} + if not all_product_ids: + return {loc_id: set() for loc_id in location_ids} + + product_tags_through = Product.tags.through + tag_model = Product.tags.tag_model + tag_field_name = tag_model._meta.model_name + product_to_tag_names: dict[int, set[str]] = defaultdict(set) + for prod_id, tag_name in product_tags_through.objects.filter( + product_id__in=all_product_ids, + ).values_list("product_id", f"{tag_field_name}__name"): + product_to_tag_names[prod_id].add(tag_name) + + return { + loc_id: { + name + for pid in pids + for name in product_to_tag_names.get(pid, set()) + } + for loc_id, pids in location_to_products.items() + } + + +def _sync_inheritance_for_qs(queryset, *, target_names_per_child): + """ + Sync inherited_tags + tags for every child in `queryset` to its target tag set. + + target_names_per_child: callable(child) -> set[str]. + + Issues bulk SQL: one through-table read for current inherited_tags, then + bulk add/remove on `tags` and `inherited_tags` fields. + """ + children = list(queryset) + if not children: + return + + model_class = type(children[0]) + inherited_field = model_class._meta.get_field("inherited_tags") + inherited_through = inherited_field.remote_field.through + inherited_tag_model = inherited_field.related_model + + # Resolve through-table FK column for the source side. 
+ source_field_name = None + for field in inherited_through._meta.fields: + if hasattr(field, "remote_field") and field.remote_field and field.remote_field.model == model_class: + source_field_name = field.name + break + + child_ids = [c.pk for c in children] + # One query: pull every (child_id, tag_name) pair from the inherited_tags through table. + existing_pairs = inherited_through.objects.filter( + **{f"{source_field_name}__in": child_ids}, + ).values_list(source_field_name, f"{inherited_tag_model._meta.model_name}__name") + + old_inherited_by_child: dict[int, set[str]] = defaultdict(set) + for child_id, tag_name in existing_pairs: + old_inherited_by_child[child_id].add(tag_name) + + # Compute per-child diff and bucket by tag name. Two diffs are computed: + # - inherited_tags add/remove: keeps the inherited_tags M2M in sync + # with the target. + # - tags re-merge: ensures every target name is also present on `tags`, + # even when inherited_tags already matched. This is the bulk + # equivalent of `make_inherited_tags_sticky` enforcement, needed for + # the importer hot path where `test.tags.set([...])` overwrites the + # full tag list inside a `tag_inheritance.batch()` block. + add_map: dict[str, list] = defaultdict(list) + remove_map: dict[str, list] = defaultdict(list) + target_per_child: dict[int, set[str]] = {} + for child in children: + target = target_names_per_child(child) + target_per_child[child.pk] = target + old = old_inherited_by_child.get(child.pk, set()) + for name in target - old: + add_map[name].append(child) + for name in old - target: + remove_map[name].append(child) + + # Apply adds. Both `tags` and `inherited_tags` get the same set of new + # inherited names — `_manage_inherited_tags` did the same. + if add_map: + bulk_add_tag_mapping(add_map, tag_field_name="inherited_tags") + bulk_add_tag_mapping(add_map, tag_field_name="tags") + + # Apply removes. 
+ for name, instances in remove_map.items(): + bulk_remove_tags_from_instances(name, instances, tag_field_name="inherited_tags") + bulk_remove_tags_from_instances(name, instances, tag_field_name="tags") + + # Bulk re-merge: ensure every target name is present on `tags`. We need + # this for the importer hot path where `tags.set([...])` inside a + # `tag_inheritance.batch()` can wipe inherited names from `tags` while + # `inherited_tags` stays in sync (so the diff above is empty). + # + # Read the current `tags` per child so we only write rows that are + # actually missing — without this guard the re-merge becomes O(target * + # children) bulk_create attempts for every product-tag toggle. + tags_field = model_class._meta.get_field("tags") + tags_through = tags_field.remote_field.through + tags_tag_model = tags_field.related_model + existing_tags_pairs = tags_through.objects.filter( + **{f"{source_field_name}__in": child_ids}, + ).values_list(source_field_name, f"{tags_tag_model._meta.model_name}__name") + current_tags_by_child: dict[int, set[str]] = defaultdict(set) + for child_id, tag_name in existing_tags_pairs: + current_tags_by_child[child_id].add(tag_name) -def propagate_tags_on_object_list(object_list): - for obj in object_list: - if obj and obj.id is not None: - logger.debug(f"\tPropagating tags to {type(obj)} - {obj}") - obj.save() + remerge_map: dict[str, list] = defaultdict(list) + for child in children: + target = target_per_child[child.pk] + current = current_tags_by_child.get(child.pk, set()) + # Skip names already added by the diff above; only fix true drift. 
+ already_added = {name for name, lst in add_map.items() if child in lst} + for name in target - current - already_added: + remerge_map[name].append(child) + if remerge_map: + bulk_add_tag_mapping(remerge_map, tag_field_name="tags") diff --git a/dojo/tag_inheritance.py b/dojo/tag_inheritance.py new file mode 100644 index 00000000000..758b702a8a3 --- /dev/null +++ b/dojo/tag_inheritance.py @@ -0,0 +1,182 @@ +""" +Tag inheritance — watson-style context manager. + +Pattern mirrors `watson.search.SearchContextManager`: signal handlers +register touched instances into the active context instead of running +per-row inheritance work; the context flushes them in bulk on +``flush()`` (called explicitly mid-batch) and on context exit. + +Usage: + with tag_inheritance.batch() as ctx: + # bulk operations create/modify many instances + ... + ctx.flush() # optional, mid-batch sync (e.g. before JIRA push) + ... + # auto-flushes on outermost exit + +The context lives in ``threading.local``, so concurrent threads (and +Celery workers in non-prefork pools) are unaffected by other threads' +batches. +""" +from __future__ import annotations + +import logging +import threading +from collections import defaultdict +from contextlib import contextmanager + +logger = logging.getLogger(__name__) + +_state = threading.local() + + +class TagInheritanceContext: + + """ + Per-thread registrar for instances whose inherited tags need + re-syncing in bulk. + + Layout: ``{product_id: {model_class: {pk, ...}}}`` for single-product + children (Engagement / Test / Finding / Endpoint), plus a separate + set of Location pks (locations are linked to many products via + LocationProductReference / LocationFindingReference, so their target + tag set is the union of all related products' tags). + + On ``flush()``: one bulk diff per (product, model) group via + ``_sync_inheritance_for_qs``; locations route through the bulk + target-map helper. 
+ """ + + def __init__(self): + self._depth = 0 + # product_id -> model_class -> set[pk] + self._touched_by_product: dict[int, dict[type, set[int]]] = defaultdict(lambda: defaultdict(set)) + # Cached resolved Product instances so flush doesn't re-read. + self._product_by_id: dict[int, object] = {} + # Locations are multi-product; tracked separately and resolved at flush. + self._touched_locations: set[int] = set() + # System-wide inheritance flag is read from the DB and cached for + # the lifetime of the context. Per-product flags are read off the + # in-memory product instance (no DB cost). + self._system_inheritance: bool | None = None + + def is_active(self) -> bool: + return self._depth > 0 + + def system_inheritance_enabled(self) -> bool: + if self._system_inheritance is None: + from dojo.utils import get_system_setting # noqa: PLC0415 + self._system_inheritance = bool(get_system_setting("enable_product_tag_inheritance")) + return self._system_inheritance + + def add(self, instance) -> None: + """ + Register an instance for bulk-sync at next flush. + + For Location: always register (filtered at flush time, since + per-location inheritance check would cost a DB query each). + + For other models: resolve product upfront (in-memory FK chain), + skip when inheritance is disabled for that product. Stays cheap + on inheritance-off products. 
+ """ + if instance is None or getattr(instance, "pk", None) is None: + return + + from dojo.location.models import Location # noqa: PLC0415 + if isinstance(instance, Location): + self._touched_locations.add(instance.pk) + return + + from dojo.tags_signals import get_products # noqa: PLC0415 + for product in get_products(instance): + if product is None: + continue + if not getattr(product, "enable_product_tag_inheritance", False): + if not self.system_inheritance_enabled(): + continue + self._touched_by_product[product.id][type(instance)].add(instance.pk) + self._product_by_id[product.id] = product + + def flush(self) -> None: + """ + Bulk-sync inherited tags for every registered instance, then + clear the registry. Idempotent and cheap when nothing was + touched. + """ + if not self._touched_by_product and not self._touched_locations: + return + # Local imports to avoid circulars at module import time. + from dojo.location.models import Location # noqa: PLC0415 + from dojo.product.helpers import ( # noqa: PLC0415 + _build_location_target_names_map, + _sync_inheritance_for_qs, + ) + + touched_by_product = self._touched_by_product + product_by_id = self._product_by_id + touched_locations = self._touched_locations + self._touched_by_product = defaultdict(lambda: defaultdict(set)) + self._product_by_id = {} + self._touched_locations = set() + + for product_id, model_pks in touched_by_product.items(): + product = product_by_id.get(product_id) + if product is None: + continue + target_tag_names = {tag.name for tag in product.tags.all()} + for model_class, pks in model_pks.items(): + if not pks: + continue + _sync_inheritance_for_qs( + model_class.objects.filter(pk__in=pks), + target_names_per_child=lambda _c, _t=target_tag_names: _t, + ) + + if touched_locations: + target_map = _build_location_target_names_map(list(touched_locations)) + _sync_inheritance_for_qs( + Location.objects.filter(pk__in=touched_locations), + target_names_per_child=lambda loc, _m=target_map: 
_m.get(loc.pk, set()), + ) + + +def current() -> TagInheritanceContext | None: + """Return the active context for this thread, if any.""" + return getattr(_state, "ctx", None) + + +def is_in_batch() -> bool: + """Return True when the current thread is inside an active ``batch()``.""" + ctx = current() + return ctx is not None and ctx.is_active() + + +@contextmanager +def batch(): + """ + Open a tag-inheritance context for the calling thread. + + Inside the context, signal handlers register touched instances + instead of running per-row inheritance. On exit, the context + auto-flushes (bulk-applies inheritance for every touched instance). + + Reentrant: nested ``with`` blocks share the context until the + outermost block exits. + """ + ctx = getattr(_state, "ctx", None) + owner = ctx is None + if owner: + ctx = TagInheritanceContext() + _state.ctx = ctx + ctx._depth += 1 + try: + yield ctx + finally: + ctx._depth -= 1 + if ctx._depth <= 0: + try: + ctx.flush() + finally: + if owner: + del _state.ctx diff --git a/dojo/tag_utils.py b/dojo/tag_utils.py index 87ed6845961..8b2509e2dab 100644 --- a/dojo/tag_utils.py +++ b/dojo/tag_utils.py @@ -164,6 +164,118 @@ def bulk_add_tags_to_instances(tag_or_tags, instances, tag_field_name: str = "ta return total_created +def bulk_remove_tags_from_instances(tag_or_tags, instances, tag_field_name: str = "tags", batch_size: int | None = None) -> int: + """ + Efficiently remove tag(s) from many model instances. + + Symmetric to ``bulk_add_tags_to_instances``: + + - tag_or_tags: a single string, an iterable of strings or tag objects, or a Tagulous edit string. + - instances: QuerySet or list of model instances of the same class. + - tag_field_name: name of the TagField on the model (default: ``"tags"``). + - Decrements ``tag.count`` for every removed (instance, tag) pair. + - Deletes through-model rows in one DELETE per tag (or batched). + - Clears Django prefetch caches on the input instances so subsequent access reloads from DB. 
+ + Returns the total number of relationships removed across all provided tags. + Tags that do not exist or are not currently associated with any instance are silently skipped. + """ + if batch_size is None: + batch_size = getattr(settings, "TAG_BULK_ADD_BATCH_SIZE", 1000) + + if hasattr(instances, "model"): + instances = list(instances) + + if not instances: + return 0 + + model_class = instances[0].__class__ + + # Mirror the Product safety check from bulk_add_tags_to_instances. Removing + # tags from a Product would normally trigger inheritance propagation via + # m2m_changed signals; this helper bypasses signals, so disallow it. + if model_class is Product: + msg = "bulk_remove_tags_from_instances: Product instances are not supported; use Product.tags.remove() or a propagation-aware helper" + raise ValueError(msg) + + try: + tag_field = model_class._meta.get_field(tag_field_name) + except Exception: + msg = f"Model {model_class.__name__} does not have field '{tag_field_name}'" + raise ValueError(msg) + + if not hasattr(tag_field, "tag_options"): + msg = f"Field '{tag_field_name}' is not a TagField" + raise ValueError(msg) + + tag_model = tag_field.related_model + through_model = tag_field.remote_field.through + + # Normalize input into a list of tag names (mirrors bulk_add_tags_to_instances). + tag_names: list[str] = [] + try: + if isinstance(tag_or_tags, str): + space_delimiter = getattr(tag_field, "tag_options", None).space_delimiter if hasattr(tag_field, "tag_options") else False + tag_names = parse_tags(tag_or_tags, space_delimiter=space_delimiter) + elif isinstance(tag_or_tags, Iterable): + tag_names = [getattr(t, "name", str(t)) for t in tag_or_tags] + else: + tag_names = [str(tag_or_tags)] + except Exception: + tag_names = [str(tag_or_tags)] + + # Resolve through-model FK names dynamically (no hard-coding). 
+ through_fields = {f.name: f for f in through_model._meta.fields} + source_field_name = None + target_field_name = None + for field_name, field in through_fields.items(): + if hasattr(field, "remote_field") and field.remote_field: + if field.remote_field.model == model_class: + source_field_name = field_name + elif field.remote_field.model == tag_model: + target_field_name = field_name + + total_removed = 0 + + for single_tag_name in tag_names: + if not single_tag_name: + continue + + # Resolve the tag — skip silently if it doesn't exist (nothing to remove). + if tag_field.tag_options.case_sensitive: + tag = tag_model.objects.filter(name=single_tag_name).first() + else: + tag = tag_model.objects.filter(name__iexact=single_tag_name).first() + if tag is None: + continue + + for i in range(0, len(instances), batch_size): + batch_instances = instances[i:i + batch_size] + batch_ids = [instance.pk for instance in batch_instances] + + with transaction.atomic(): + # One DELETE per tag-batch. Returns the deleted-row count. + deleted_count, _ = through_model.objects.filter( + **{target_field_name: tag.pk}, + **{f"{source_field_name}__in": batch_ids}, + ).delete() + + if deleted_count: + total_removed += deleted_count + # Decrement the Tagulous-maintained count to avoid drift. + tag_model.objects.filter(pk=tag.pk).update( + count=models.F("count") - deleted_count, + ) + + # Invalidate prefetch caches so callers see the new state. 
+ for instance in batch_instances: + prefetch_cache = getattr(instance, "_prefetched_objects_cache", None) + if prefetch_cache is not None: + prefetch_cache.pop(tag_field_name, None) + + return total_removed + + def bulk_add_tag_mapping( tag_to_instances: dict[str, list], tag_field_name: str = "tags", @@ -410,4 +522,4 @@ def bulk_remove_all_tags(model_class, instance_ids_qs): ) -__all__ = ["bulk_add_tag_mapping", "bulk_add_tags_to_instances", "bulk_apply_parser_tags", "bulk_remove_all_tags"] +__all__ = ["bulk_add_tag_mapping", "bulk_add_tags_to_instances", "bulk_apply_parser_tags", "bulk_remove_all_tags", "bulk_remove_tags_from_instances"] diff --git a/dojo/tags_signals.py b/dojo/tags_signals.py index 0fea7ae8ad5..aef1bfb0e14 100644 --- a/dojo/tags_signals.py +++ b/dojo/tags_signals.py @@ -4,6 +4,7 @@ from django.db.models import signals from django.dispatch import receiver +from dojo import tag_inheritance from dojo.celery_dispatch import dojo_dispatch_task from dojo.location.models import Location, LocationFindingReference, LocationProductReference from dojo.models import Endpoint, Engagement, Finding, Product, Test @@ -32,11 +33,18 @@ def product_tags_post_add_remove(sender, instance, action, **kwargs): @receiver(signals.m2m_changed, sender=Location.tags.through) def make_inherited_tags_sticky(sender, instance, action, **kwargs): """Make sure inherited tags are added back in if they are removed""" - if action in {"post_add", "post_remove"}: - if inherit_product_tags(instance): - tag_list = [tag.name for tag in instance.tags.all()] - if propagate_inheritance(instance, tag_list=tag_list): - instance.inherit_tags(tag_list) + if action not in {"post_add", "post_remove"}: + return + # Inside a `tag_inheritance.batch()` context, register the instance + # for bulk-sync at flush/exit instead of running per-row inheritance. 
+ ctx = tag_inheritance.current() + if ctx is not None and ctx.is_active(): + ctx.add(instance) + return + if inherit_product_tags(instance): + tag_list = [tag.name for tag in instance.tags.all()] + if propagate_inheritance(instance, tag_list=tag_list): + instance.inherit_tags(tag_list) def inherit_instance_tags(instance): @@ -59,12 +67,30 @@ def inherit_linked_instance_tags(instance: LocationFindingReference | LocationPr @receiver(signals.post_save, sender=Finding) @receiver(signals.post_save, sender=Location) def inherit_tags_on_instance(sender, instance, created, **kwargs): + # Only inherit on creation. The previous behavior fired on every save + # (create OR update), repeatedly re-applying inherited tags to children + # whose tag state had not changed. Sticky enforcement on user-driven + # tag edits is handled by `make_inherited_tags_sticky` (m2m_changed). + if not created: + return + # Inside a `tag_inheritance.batch()` context, register the new instance + # for bulk-sync at flush/exit instead of running per-row inheritance. + ctx = tag_inheritance.current() + if ctx is not None and ctx.is_active(): + ctx.add(instance) + return inherit_instance_tags(instance) @receiver(signals.post_save, sender=LocationFindingReference) @receiver(signals.post_save, sender=LocationProductReference) def inherit_tags_on_linked_instance(sender, instance, created, **kwargs): + # Linked refs (LocationFinding/LocationProductReference) bind a Location + # to a Finding/Product. Register the underlying Location for bulk-sync. 
+ ctx = tag_inheritance.current() + if ctx is not None and ctx.is_active(): + ctx.add(instance.location) + return inherit_linked_instance_tags(instance) diff --git a/unittests/test_tag_inheritance_perf.py b/unittests/test_tag_inheritance_perf.py index 450a3d2ab89..12a1f69334f 100644 --- a/unittests/test_tag_inheritance_perf.py +++ b/unittests/test_tag_inheritance_perf.py @@ -357,10 +357,14 @@ def test_baseline_product_tag_remove_propagates_to_100_locations_v3(self): # the appropriate mode so all variants execute in a single suite run. # Findings-only scenarios. - EXPECTED_PRODUCT_TAG_ADD_100_V2 = 4758 - EXPECTED_PRODUCT_TAG_ADD_100_V3 = 4759 - EXPECTED_PRODUCT_TAG_REMOVE_100_V2 = 4540 - EXPECTED_PRODUCT_TAG_REMOVE_100_V3 = 4541 + # Pre-Phase-A V2: 4758 add, 4540 remove. V3: 4759/4541. + # Phase A bulk-propagate drops these dramatically. + # Phase B Stage 2 adds ~3 queries to read current `tags` for the bulk + # re-merge step (compensates for tags wiped inside batch contexts). + EXPECTED_PRODUCT_TAG_ADD_100_V2 = 94 + EXPECTED_PRODUCT_TAG_ADD_100_V3 = 94 + EXPECTED_PRODUCT_TAG_REMOVE_100_V2 = 56 + EXPECTED_PRODUCT_TAG_REMOVE_100_V3 = 56 EXPECTED_CREATE_ONE_FINDING_V2 = 64 EXPECTED_CREATE_ONE_FINDING_V3 = 64 @@ -372,13 +376,20 @@ def test_baseline_product_tag_remove_propagates_to_100_locations_v3(self): EXPECTED_FINDING_REMOVE_INHERITED_V2 = 44 EXPECTED_FINDING_REMOVE_INHERITED_V3 = 44 - # V2 endpoint paths (Endpoints have no V3 counterpart in this class). - EXPECTED_PRODUCT_TAG_ADD_100_ENDPOINTS = 3958 - EXPECTED_PRODUCT_TAG_REMOVE_100_ENDPOINTS = 3740 + # V2 endpoint paths. Pre-Phase-A: 3958 add, 3740 remove. + # Phase B Stage 2 raises endpoint add 91 -> 194 because the eager Celery + # propagate dispatched by m2m_changed and the explicit + # propagate_tags_on_product_sync call both pay the new tags-read for + # bulk re-merge. The bulk-propagate + batch-context + Location precompute + # commits later in Stage 2 collapse this back to 94. 
Will go further down
+    # in Stages 3+4+5 when the duplicate inherited_tags M2M is dropped.
+    EXPECTED_PRODUCT_TAG_ADD_100_ENDPOINTS = 94
+    EXPECTED_PRODUCT_TAG_REMOVE_100_ENDPOINTS = 56
 
-    # V3 location paths (LocationManager has no V2 counterpart in this class).
-    EXPECTED_PRODUCT_TAG_ADD_100_LOCATIONS = 4531
-    EXPECTED_PRODUCT_TAG_REMOVE_100_LOCATIONS = 4307
+    # V3 location paths. Pre-Phase-A: 4531 add, 4307 remove.
+    # Phase B Stage 2 + location precompute: bulk-built target-name map.
+    EXPECTED_PRODUCT_TAG_ADD_100_LOCATIONS = 123
+    EXPECTED_PRODUCT_TAG_REMOVE_100_LOCATIONS = 73
 
 
 @override_settings(
@@ -488,7 +499,18 @@ def test_baseline_zap_scan_reimport_no_change_v3(self):
     # Pinned baselines per mode. Each test forces its own V3_FEATURE_LOCATIONS
     # via @override_settings so all four import paths run in a single suite
     # invocation regardless of the ambient `DD_V3_FEATURE_LOCATIONS` env var.
-    EXPECTED_ZAP_IMPORT_V2 = 1461
-    EXPECTED_ZAP_IMPORT_V3 = 1319
-    EXPECTED_ZAP_REIMPORT_NO_CHANGE_V2 = 77
-    EXPECTED_ZAP_REIMPORT_NO_CHANGE_V3 = 95
+    # Phase A nudges these slightly downward (post_save gated on created=True
+    # avoids re-running inheritance on no-op finding updates during reimport).
+    # Pre-Phase-A: 1461/1319 import, 77/95 reimport.
+    # Phase B Stage 1 (thread-safe batch context) adds ~20 queries on the V3
+    # import path because the previous process-global signal-disconnect was
+    # narrower in scope (Location.tags.through only).
+    # Phase B Stage 2 (importer-wide batch + flush_for_product) drops import
+    # ~27%/22% (signal cascade replaced by single bulk propagate). Reimport
+    # rises because flush always runs; bulk re-merge has a fixed cost even
+    # when there's no work. Stages 3+4+5 (drop duplicate inherited_tags M2M)
+    # will collapse the reimport cost.
+    EXPECTED_ZAP_IMPORT_V2 = 1000
+    EXPECTED_ZAP_IMPORT_V3 = 941
+    EXPECTED_ZAP_REIMPORT_NO_CHANGE_V2 = 69
+    EXPECTED_ZAP_REIMPORT_NO_CHANGE_V3 = 87