diff --git a/dojo/product/helpers.py b/dojo/product/helpers.py index cdb0750f317..7bbb2937103 100644 --- a/dojo/product/helpers.py +++ b/dojo/product/helpers.py @@ -1,5 +1,6 @@ import contextlib import logging +from collections import defaultdict from django.conf import settings from django.db.models import Q @@ -7,6 +8,7 @@ from dojo.celery import app from dojo.location.models import Location from dojo.models import Endpoint, Engagement, Finding, Product, Test +from dojo.tag_utils import bulk_add_tag_mapping, bulk_remove_tags_from_instances logger = logging.getLogger(__name__) @@ -19,35 +21,115 @@ def propagate_tags_on_product(product_id, *args, **kwargs): def propagate_tags_on_product_sync(product): - # enagagements + """ + Bulk-apply Product tag changes to all children using through-table SQL. + + Replaces the previous per-row `.save()` loop. For every child model owned + by the product (Engagement, Test, Finding, plus Endpoint or Location + depending on the V3_FEATURE_LOCATIONS flag), reads the existing + `inherited_tags` per child in one query, computes the diff against the + Product's current tags, and applies adds/removes via the bulk tag + helpers. Both `tags` and `inherited_tags` fields are kept in sync. + """ + target_names = {tag.name for tag in product.tags.all()} + logger.debug("Propagating tags from %s to all engagements", product) - propagate_tags_on_object_list(Engagement.objects.filter(product=product)) - # tests + _sync_inheritance_for_qs( + Engagement.objects.filter(product=product), + target_names_per_child=lambda _child: target_names, + ) logger.debug("Propagating tags from %s to all tests", product) - propagate_tags_on_object_list(Test.objects.filter(engagement__product=product)) - # findings + _sync_inheritance_for_qs( + Test.objects.filter(engagement__product=product), + target_names_per_child=lambda _child: target_names, + ) logger.debug("Propagating tags from %s to all findings", product) - propagate_tags_on_object_list(Finding.objects.filter(test__engagement__product=product)) + _sync_inheritance_for_qs( + Finding.objects.filter(test__engagement__product=product), + target_names_per_child=lambda _child: target_names, + ) if settings.V3_FEATURE_LOCATIONS: - # Locations logger.debug("Propagating tags from %s to all locations", product) - propagate_tags_on_object_list( - Location.objects.filter( - # Locations linked directly to a product via LocationProductReference - Q(products__product=product) - # Locations linked indirectly to a product via LocationFindingReference - | Q(findings__finding__test__engagement__product=product), - ).distinct(), + location_qs = Location.objects.filter( + Q(products__product=product) + | Q(findings__finding__test__engagement__product=product), + ).distinct() + # Locations can be linked to multiple products, so the inherited target + # is the union of every related product's tags. Compute per-location. + _sync_inheritance_for_qs( + location_qs, + target_names_per_child=_location_target_names, ) else: - # TODO: Delete this after the move to Locations - # endpoints logger.debug("Propagating tags from %s to all endpoints", product) - propagate_tags_on_object_list(Endpoint.objects.filter(product=product)) + _sync_inheritance_for_qs( + Endpoint.objects.filter(product=product), + target_names_per_child=lambda _child: target_names, + ) + + +def _location_target_names(location): + names: set[str] = set() + for related_product in location.all_related_products(): + if related_product is None: + continue + names.update(tag.name for tag in related_product.tags.all()) + return names + + +def _sync_inheritance_for_qs(queryset, *, target_names_per_child): + """ + Sync inherited_tags + tags for every child in `queryset` to its target tag set. + + target_names_per_child: callable(child) -> set[str]. + + Issues bulk SQL: one through-table read for current inherited_tags, then + bulk add/remove on `tags` and `inherited_tags` fields. + """ + children = list(queryset) + if not children: + return + + model_class = type(children[0]) + inherited_field = model_class._meta.get_field("inherited_tags") + inherited_through = inherited_field.remote_field.through + inherited_tag_model = inherited_field.related_model + + # Resolve through-table FK column for the source side. + source_field_name = None + for field in inherited_through._meta.fields: + if hasattr(field, "remote_field") and field.remote_field and field.remote_field.model == model_class: + source_field_name = field.name + break + + child_ids = [c.pk for c in children] + # One query: pull every (child_id, tag_name) pair from the inherited_tags through table. + existing_pairs = inherited_through.objects.filter( + **{f"{source_field_name}__in": child_ids}, + ).values_list(source_field_name, f"{inherited_tag_model._meta.model_name}__name") + + old_inherited_by_child: dict[int, set[str]] = defaultdict(set) + for child_id, tag_name in existing_pairs: + old_inherited_by_child[child_id].add(tag_name) + + # Compute per-child diff and bucket by tag name. + add_map: dict[str, list] = defaultdict(list) + remove_map: dict[str, list] = defaultdict(list) + for child in children: + target = target_names_per_child(child) + old = old_inherited_by_child.get(child.pk, set()) + for name in target - old: + add_map[name].append(child) + for name in old - target: + remove_map[name].append(child) + # Apply adds. Both `tags` and `inherited_tags` get the same set of new + # inherited names — `_manage_inherited_tags` did the same. + if add_map: + bulk_add_tag_mapping(add_map, tag_field_name="inherited_tags") + bulk_add_tag_mapping(add_map, tag_field_name="tags") -def propagate_tags_on_object_list(object_list): - for obj in object_list: - if obj and obj.id is not None: - logger.debug(f"\tPropagating tags to {type(obj)} - {obj}") - obj.save() + # Apply removes. + for name, instances in remove_map.items(): + bulk_remove_tags_from_instances(name, instances, tag_field_name="inherited_tags") + bulk_remove_tags_from_instances(name, instances, tag_field_name="tags") diff --git a/dojo/tag_utils.py b/dojo/tag_utils.py index 87ed6845961..8b2509e2dab 100644 --- a/dojo/tag_utils.py +++ b/dojo/tag_utils.py @@ -164,6 +164,118 @@ def bulk_add_tags_to_instances(tag_or_tags, instances, tag_field_name: str = "ta return total_created +def bulk_remove_tags_from_instances(tag_or_tags, instances, tag_field_name: str = "tags", batch_size: int | None = None) -> int: + """ + Efficiently remove tag(s) from many model instances. + + Symmetric to ``bulk_add_tags_to_instances``: + + - tag_or_tags: a single string, an iterable of strings or tag objects, or a Tagulous edit string. + - instances: QuerySet or list of model instances of the same class. + - tag_field_name: name of the TagField on the model (default: ``"tags"``). + - Decrements ``tag.count`` for every removed (instance, tag) pair. + - Deletes through-model rows in one DELETE per tag (or batched). + - Clears Django prefetch caches on the input instances so subsequent access reloads from DB. + + Returns the total number of relationships removed across all provided tags. + Tags that do not exist or are not currently associated with any instance are silently skipped. + """ + if batch_size is None: + batch_size = getattr(settings, "TAG_BULK_ADD_BATCH_SIZE", 1000) + + if hasattr(instances, "model"): + instances = list(instances) + + if not instances: + return 0 + + model_class = instances[0].__class__ + + # Mirror the Product safety check from bulk_add_tags_to_instances. Removing + # tags from a Product would normally trigger inheritance propagation via + # m2m_changed signals; this helper bypasses signals, so disallow it. + if model_class is Product: + msg = "bulk_remove_tags_from_instances: Product instances are not supported; use Product.tags.remove() or a propagation-aware helper" + raise ValueError(msg) + + try: + tag_field = model_class._meta.get_field(tag_field_name) + except Exception: + msg = f"Model {model_class.__name__} does not have field '{tag_field_name}'" + raise ValueError(msg) + + if not hasattr(tag_field, "tag_options"): + msg = f"Field '{tag_field_name}' is not a TagField" + raise ValueError(msg) + + tag_model = tag_field.related_model + through_model = tag_field.remote_field.through + + # Normalize input into a list of tag names (mirrors bulk_add_tags_to_instances). + tag_names: list[str] = [] + try: + if isinstance(tag_or_tags, str): + space_delimiter = getattr(tag_field, "tag_options", None).space_delimiter if hasattr(tag_field, "tag_options") else False + tag_names = parse_tags(tag_or_tags, space_delimiter=space_delimiter) + elif isinstance(tag_or_tags, Iterable): + tag_names = [getattr(t, "name", str(t)) for t in tag_or_tags] + else: + tag_names = [str(tag_or_tags)] + except Exception: + tag_names = [str(tag_or_tags)] + + # Resolve through-model FK names dynamically (no hard-coding). + through_fields = {f.name: f for f in through_model._meta.fields} + source_field_name = None + target_field_name = None + for field_name, field in through_fields.items(): + if hasattr(field, "remote_field") and field.remote_field: + if field.remote_field.model == model_class: + source_field_name = field_name + elif field.remote_field.model == tag_model: + target_field_name = field_name + + total_removed = 0 + + for single_tag_name in tag_names: + if not single_tag_name: + continue + + # Resolve the tag — skip silently if it doesn't exist (nothing to remove). + if tag_field.tag_options.case_sensitive: + tag = tag_model.objects.filter(name=single_tag_name).first() + else: + tag = tag_model.objects.filter(name__iexact=single_tag_name).first() + if tag is None: + continue + + for i in range(0, len(instances), batch_size): + batch_instances = instances[i:i + batch_size] + batch_ids = [instance.pk for instance in batch_instances] + + with transaction.atomic(): + # One DELETE per tag-batch. Returns the deleted-row count. + deleted_count, _ = through_model.objects.filter( + **{target_field_name: tag.pk}, + **{f"{source_field_name}__in": batch_ids}, + ).delete() + + if deleted_count: + total_removed += deleted_count + # Decrement the Tagulous-maintained count to avoid drift. + tag_model.objects.filter(pk=tag.pk).update( + count=models.F("count") - deleted_count, + ) + + # Invalidate prefetch caches so callers see the new state. + for instance in batch_instances: + prefetch_cache = getattr(instance, "_prefetched_objects_cache", None) + if prefetch_cache is not None: + prefetch_cache.pop(tag_field_name, None) + + return total_removed + + def bulk_add_tag_mapping( tag_to_instances: dict[str, list], tag_field_name: str = "tags", @@ -410,4 +522,4 @@ def bulk_remove_all_tags(model_class, instance_ids_qs): ) -__all__ = ["bulk_add_tag_mapping", "bulk_add_tags_to_instances", "bulk_apply_parser_tags", "bulk_remove_all_tags"] +__all__ = ["bulk_add_tag_mapping", "bulk_add_tags_to_instances", "bulk_apply_parser_tags", "bulk_remove_all_tags", "bulk_remove_tags_from_instances"] diff --git a/dojo/tags_signals.py b/dojo/tags_signals.py index 0fea7ae8ad5..6fe142f2e55 100644 --- a/dojo/tags_signals.py +++ b/dojo/tags_signals.py @@ -59,6 +59,12 @@ def inherit_linked_instance_tags(instance: LocationFindingReference | LocationPr @receiver(signals.post_save, sender=Finding) @receiver(signals.post_save, sender=Location) def inherit_tags_on_instance(sender, instance, created, **kwargs): + # Only inherit on creation. The previous behavior fired on every save + # (create OR update), repeatedly re-applying inherited tags to children + # whose tag state had not changed. Sticky enforcement on user-driven + # tag edits is handled by `make_inherited_tags_sticky` (m2m_changed). + if not created: + return inherit_instance_tags(instance) diff --git a/unittests/test_tag_inheritance_perf.py b/unittests/test_tag_inheritance_perf.py index 450a3d2ab89..8e9f9a5f792 100644 --- a/unittests/test_tag_inheritance_perf.py +++ b/unittests/test_tag_inheritance_perf.py @@ -357,10 +357,12 @@ def test_baseline_product_tag_remove_propagates_to_100_locations_v3(self): # the appropriate mode so all variants execute in a single suite run. # Findings-only scenarios. - EXPECTED_PRODUCT_TAG_ADD_100_V2 = 4758 - EXPECTED_PRODUCT_TAG_ADD_100_V3 = 4759 - EXPECTED_PRODUCT_TAG_REMOVE_100_V2 = 4540 - EXPECTED_PRODUCT_TAG_REMOVE_100_V3 = 4541 + # Pre-Phase-A V2: 4758 add, 4540 remove. V3: 4759/4541. + # Phase A bulk-propagate drops these dramatically. + EXPECTED_PRODUCT_TAG_ADD_100_V2 = 91 + EXPECTED_PRODUCT_TAG_ADD_100_V3 = 91 + EXPECTED_PRODUCT_TAG_REMOVE_100_V2 = 53 + EXPECTED_PRODUCT_TAG_REMOVE_100_V3 = 53 EXPECTED_CREATE_ONE_FINDING_V2 = 64 EXPECTED_CREATE_ONE_FINDING_V3 = 64 @@ -372,13 +374,13 @@ def test_baseline_product_tag_remove_propagates_to_100_locations_v3(self): EXPECTED_FINDING_REMOVE_INHERITED_V2 = 44 EXPECTED_FINDING_REMOVE_INHERITED_V3 = 44 - # V2 endpoint paths (Endpoints have no V3 counterpart in this class). - EXPECTED_PRODUCT_TAG_ADD_100_ENDPOINTS = 3958 - EXPECTED_PRODUCT_TAG_REMOVE_100_ENDPOINTS = 3740 + # V2 endpoint paths. Pre-Phase-A: 3958 add, 3740 remove. + EXPECTED_PRODUCT_TAG_ADD_100_ENDPOINTS = 91 + EXPECTED_PRODUCT_TAG_REMOVE_100_ENDPOINTS = 53 - # V3 location paths (LocationManager has no V2 counterpart in this class). - EXPECTED_PRODUCT_TAG_ADD_100_LOCATIONS = 4531 - EXPECTED_PRODUCT_TAG_REMOVE_100_LOCATIONS = 4307 + # V3 location paths. Pre-Phase-A: 4532 add, 4307 remove. + EXPECTED_PRODUCT_TAG_ADD_100_LOCATIONS = 316 + EXPECTED_PRODUCT_TAG_REMOVE_100_LOCATIONS = 266 @override_settings( @@ -488,7 +490,10 @@ def test_baseline_zap_scan_reimport_no_change_v3(self): # Pinned baselines per mode. Each test forces its own V3_FEATURE_LOCATIONS # via @override_settings so all four import paths run in a single suite # invocation regardless of the ambient `DD_V3_FEATURE_LOCATIONS` env var. - EXPECTED_ZAP_IMPORT_V2 = 1461 - EXPECTED_ZAP_IMPORT_V3 = 1319 - EXPECTED_ZAP_REIMPORT_NO_CHANGE_V2 = 77 - EXPECTED_ZAP_REIMPORT_NO_CHANGE_V3 = 95 + # Phase A nudges these slightly downward (post_save gated on created=True + # avoids re-running inheritance on no-op finding updates during reimport). + # Pre-Phase-A: 1461/1319 import, 77/95 reimport. + EXPECTED_ZAP_IMPORT_V2 = 1385 + EXPECTED_ZAP_IMPORT_V3 = 1243 + EXPECTED_ZAP_REIMPORT_NO_CHANGE_V2 = 69 + EXPECTED_ZAP_REIMPORT_NO_CHANGE_V3 = 87