Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
126 changes: 104 additions & 22 deletions dojo/product/helpers.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import contextlib
import logging
from collections import defaultdict

from django.conf import settings
from django.db.models import Q

from dojo.celery import app
from dojo.location.models import Location
from dojo.models import Endpoint, Engagement, Finding, Product, Test
from dojo.tag_utils import bulk_add_tag_mapping, bulk_remove_tags_from_instances

logger = logging.getLogger(__name__)

Expand All @@ -19,35 +21,115 @@ def propagate_tags_on_product(product_id, *args, **kwargs):


def propagate_tags_on_product_sync(product):
    """
    Bulk-apply Product tag changes to all children using through-table SQL.

    Replaces the previous per-row `.save()` loop. For every child model owned
    by the product (Engagement, Test, Finding, plus Endpoint or Location
    depending on the V3_FEATURE_LOCATIONS flag), reads the existing
    `inherited_tags` per child in one query, computes the diff against the
    Product's current tags, and applies adds/removes via the bulk tag
    helpers. Both `tags` and `inherited_tags` fields are kept in sync.
    """
    # Snapshot the product's tag names once; every directly-owned child
    # (engagement/test/finding/endpoint) shares this same target set.
    target_names = {tag.name for tag in product.tags.all()}

    logger.debug("Propagating tags from %s to all engagements", product)
    _sync_inheritance_for_qs(
        Engagement.objects.filter(product=product),
        target_names_per_child=lambda _child: target_names,
    )
    logger.debug("Propagating tags from %s to all tests", product)
    _sync_inheritance_for_qs(
        Test.objects.filter(engagement__product=product),
        target_names_per_child=lambda _child: target_names,
    )
    logger.debug("Propagating tags from %s to all findings", product)
    _sync_inheritance_for_qs(
        Finding.objects.filter(test__engagement__product=product),
        target_names_per_child=lambda _child: target_names,
    )
    if settings.V3_FEATURE_LOCATIONS:
        logger.debug("Propagating tags from %s to all locations", product)
        location_qs = Location.objects.filter(
            # Locations linked directly to a product via LocationProductReference
            Q(products__product=product)
            # Locations linked indirectly to a product via LocationFindingReference
            | Q(findings__finding__test__engagement__product=product),
        ).distinct()
        # Locations can be linked to multiple products, so the inherited target
        # is the union of every related product's tags. Compute per-location.
        _sync_inheritance_for_qs(
            location_qs,
            target_names_per_child=_location_target_names,
        )
    else:
        # TODO: Delete this after the move to Locations
        logger.debug("Propagating tags from %s to all endpoints", product)
        _sync_inheritance_for_qs(
            Endpoint.objects.filter(product=product),
            target_names_per_child=lambda _child: target_names,
        )


def _location_target_names(location):
names: set[str] = set()
for related_product in location.all_related_products():
if related_product is None:
continue
names.update(tag.name for tag in related_product.tags.all())
return names


def _sync_inheritance_for_qs(queryset, *, target_names_per_child):
    """
    Sync inherited_tags + tags for every child in `queryset` to its target tag set.

    target_names_per_child: callable(child) -> set[str].

    Issues bulk SQL: one through-table read for current inherited_tags, then
    bulk add/remove on `tags` and `inherited_tags` fields.

    Raises:
        ValueError: if the through table's FK back to the child model cannot
            be resolved (should never happen for a Tagulous-managed field).
    """
    children = list(queryset)
    if not children:
        return

    model_class = type(children[0])
    inherited_field = model_class._meta.get_field("inherited_tags")
    inherited_through = inherited_field.remote_field.through
    inherited_tag_model = inherited_field.related_model

    # Resolve through-table FK column for the source (child) side.
    source_field_name = None
    for field in inherited_through._meta.fields:
        if hasattr(field, "remote_field") and field.remote_field and field.remote_field.model == model_class:
            source_field_name = field.name
            break
    # Fail loudly instead of silently building a bogus "None__in" filter.
    if source_field_name is None:
        msg = f"Could not resolve through-table FK to {model_class.__name__} for 'inherited_tags'"
        raise ValueError(msg)

    child_ids = [child.pk for child in children]
    # One query: pull every (child_id, tag_name) pair from the inherited_tags through table.
    existing_pairs = inherited_through.objects.filter(
        **{f"{source_field_name}__in": child_ids},
    ).values_list(source_field_name, f"{inherited_tag_model._meta.model_name}__name")

    old_inherited_by_child: dict[int, set[str]] = defaultdict(set)
    for child_id, tag_name in existing_pairs:
        old_inherited_by_child[child_id].add(tag_name)

    # Compute per-child diff and bucket by tag name.
    add_map: dict[str, list] = defaultdict(list)
    remove_map: dict[str, list] = defaultdict(list)
    for child in children:
        target = target_names_per_child(child)
        old = old_inherited_by_child.get(child.pk, set())
        for name in target - old:
            add_map[name].append(child)
        for name in old - target:
            remove_map[name].append(child)

    # Apply adds. Both `tags` and `inherited_tags` get the same set of new
    # inherited names — `_manage_inherited_tags` did the same.
    if add_map:
        bulk_add_tag_mapping(add_map, tag_field_name="inherited_tags")
        bulk_add_tag_mapping(add_map, tag_field_name="tags")

    # Apply removes.
    for name, instances in remove_map.items():
        bulk_remove_tags_from_instances(name, instances, tag_field_name="inherited_tags")
        bulk_remove_tags_from_instances(name, instances, tag_field_name="tags")
114 changes: 113 additions & 1 deletion dojo/tag_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,118 @@ def bulk_add_tags_to_instances(tag_or_tags, instances, tag_field_name: str = "ta
return total_created


def bulk_remove_tags_from_instances(tag_or_tags, instances, tag_field_name: str = "tags", batch_size: int | None = None) -> int:
    """
    Efficiently remove tag(s) from many model instances.

    Symmetric to ``bulk_add_tags_to_instances``:

    - tag_or_tags: a single string, an iterable of strings or tag objects, or a Tagulous edit string.
    - instances: QuerySet or list of model instances of the same class.
    - tag_field_name: name of the TagField on the model (default: ``"tags"``).
    - Decrements ``tag.count`` for every removed (instance, tag) pair.
    - Deletes through-model rows in one DELETE per tag (or batched).
    - Clears Django prefetch caches on the input instances so subsequent access reloads from DB.

    Returns the total number of relationships removed across all provided tags.
    Tags that do not exist or are not currently associated with any instance are silently skipped.

    Raises:
        ValueError: for Product instances (would bypass inheritance propagation),
            a missing field, or a field that is not a TagField.
    """
    if batch_size is None:
        batch_size = getattr(settings, "TAG_BULK_ADD_BATCH_SIZE", 1000)

    # A QuerySet exposes `.model`; materialize once so we can len()/slice batches.
    if hasattr(instances, "model"):
        instances = list(instances)

    if not instances:
        return 0

    model_class = instances[0].__class__

    # Mirror the Product safety check from bulk_add_tags_to_instances. Removing
    # tags from a Product would normally trigger inheritance propagation via
    # m2m_changed signals; this helper bypasses signals, so disallow it.
    if model_class is Product:
        msg = "bulk_remove_tags_from_instances: Product instances are not supported; use Product.tags.remove() or a propagation-aware helper"
        raise ValueError(msg)

    try:
        tag_field = model_class._meta.get_field(tag_field_name)
    except Exception as exc:
        msg = f"Model {model_class.__name__} does not have field '{tag_field_name}'"
        # Chain the cause so the original FieldDoesNotExist is not lost.
        raise ValueError(msg) from exc

    if not hasattr(tag_field, "tag_options"):
        msg = f"Field '{tag_field_name}' is not a TagField"
        raise ValueError(msg)

    tag_model = tag_field.related_model
    through_model = tag_field.remote_field.through

    # Normalize input into a list of tag names (mirrors bulk_add_tags_to_instances).
    tag_names: list[str] = []
    try:
        if isinstance(tag_or_tags, str):
            # tag_options is guaranteed present by the guard above.
            tag_names = parse_tags(tag_or_tags, space_delimiter=tag_field.tag_options.space_delimiter)
        elif isinstance(tag_or_tags, Iterable):
            tag_names = [getattr(t, "name", str(t)) for t in tag_or_tags]
        else:
            tag_names = [str(tag_or_tags)]
    except Exception:
        # Best-effort fallback: treat the raw input as a single tag name.
        tag_names = [str(tag_or_tags)]

    # Resolve through-model FK names dynamically (no hard-coding).
    through_fields = {f.name: f for f in through_model._meta.fields}
    source_field_name = None
    target_field_name = None
    for field_name, field in through_fields.items():
        if hasattr(field, "remote_field") and field.remote_field:
            if field.remote_field.model == model_class:
                source_field_name = field_name
            elif field.remote_field.model == tag_model:
                target_field_name = field_name

    total_removed = 0

    for single_tag_name in tag_names:
        if not single_tag_name:
            continue

        # Resolve the tag — skip silently if it doesn't exist (nothing to remove).
        if tag_field.tag_options.case_sensitive:
            tag = tag_model.objects.filter(name=single_tag_name).first()
        else:
            tag = tag_model.objects.filter(name__iexact=single_tag_name).first()
        if tag is None:
            continue

        for i in range(0, len(instances), batch_size):
            batch_instances = instances[i:i + batch_size]
            batch_ids = [instance.pk for instance in batch_instances]

            with transaction.atomic():
                # One DELETE per tag-batch. Returns the deleted-row count.
                deleted_count, _ = through_model.objects.filter(
                    **{target_field_name: tag.pk},
                    **{f"{source_field_name}__in": batch_ids},
                ).delete()

                if deleted_count:
                    total_removed += deleted_count
                    # Decrement the Tagulous-maintained count to avoid drift.
                    tag_model.objects.filter(pk=tag.pk).update(
                        count=models.F("count") - deleted_count,
                    )

            # Invalidate prefetch caches so callers see the new state.
            for instance in batch_instances:
                prefetch_cache = getattr(instance, "_prefetched_objects_cache", None)
                if prefetch_cache is not None:
                    prefetch_cache.pop(tag_field_name, None)

    return total_removed


def bulk_add_tag_mapping(
tag_to_instances: dict[str, list],
tag_field_name: str = "tags",
Expand Down Expand Up @@ -410,4 +522,4 @@ def bulk_remove_all_tags(model_class, instance_ids_qs):
)


__all__ = ["bulk_add_tag_mapping", "bulk_add_tags_to_instances", "bulk_apply_parser_tags", "bulk_remove_all_tags"]
__all__ = ["bulk_add_tag_mapping", "bulk_add_tags_to_instances", "bulk_apply_parser_tags", "bulk_remove_all_tags", "bulk_remove_tags_from_instances"]
6 changes: 6 additions & 0 deletions dojo/tags_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,12 @@ def inherit_linked_instance_tags(instance: LocationFindingReference | LocationPr
@receiver(signals.post_save, sender=Finding)
@receiver(signals.post_save, sender=Location)
def inherit_tags_on_instance(sender, instance, created, **kwargs):
    # Inheritance runs on creation only. Firing on every save (create OR
    # update) repeatedly re-applied inherited tags to children whose tag
    # state had not changed. Sticky enforcement on user-driven tag edits is
    # handled by `make_inherited_tags_sticky` (m2m_changed).
    if created:
        inherit_instance_tags(instance)


Expand Down
33 changes: 19 additions & 14 deletions unittests/test_tag_inheritance_perf.py
Original file line number Diff line number Diff line change
Expand Up @@ -357,10 +357,12 @@ def test_baseline_product_tag_remove_propagates_to_100_locations_v3(self):
# the appropriate mode so all variants execute in a single suite run.

# Findings-only scenarios.
EXPECTED_PRODUCT_TAG_ADD_100_V2 = 4758
EXPECTED_PRODUCT_TAG_ADD_100_V3 = 4759
EXPECTED_PRODUCT_TAG_REMOVE_100_V2 = 4540
EXPECTED_PRODUCT_TAG_REMOVE_100_V3 = 4541
# Pre-Phase-A V2: 4758 add, 4540 remove. V3: 4759/4541.
# Phase A bulk-propagate drops these dramatically.
EXPECTED_PRODUCT_TAG_ADD_100_V2 = 91
EXPECTED_PRODUCT_TAG_ADD_100_V3 = 91
EXPECTED_PRODUCT_TAG_REMOVE_100_V2 = 53
EXPECTED_PRODUCT_TAG_REMOVE_100_V3 = 53

EXPECTED_CREATE_ONE_FINDING_V2 = 64
EXPECTED_CREATE_ONE_FINDING_V3 = 64
Expand All @@ -372,13 +374,13 @@ def test_baseline_product_tag_remove_propagates_to_100_locations_v3(self):
EXPECTED_FINDING_REMOVE_INHERITED_V2 = 44
EXPECTED_FINDING_REMOVE_INHERITED_V3 = 44

# V2 endpoint paths (Endpoints have no V3 counterpart in this class).
EXPECTED_PRODUCT_TAG_ADD_100_ENDPOINTS = 3958
EXPECTED_PRODUCT_TAG_REMOVE_100_ENDPOINTS = 3740
# V2 endpoint paths. Pre-Phase-A: 3958 add, 3740 remove.
EXPECTED_PRODUCT_TAG_ADD_100_ENDPOINTS = 91
EXPECTED_PRODUCT_TAG_REMOVE_100_ENDPOINTS = 53

# V3 location paths (LocationManager has no V2 counterpart in this class).
EXPECTED_PRODUCT_TAG_ADD_100_LOCATIONS = 4531
EXPECTED_PRODUCT_TAG_REMOVE_100_LOCATIONS = 4307
# V3 location paths. Pre-Phase-A: 4532 add, 4307 remove.
EXPECTED_PRODUCT_TAG_ADD_100_LOCATIONS = 316
EXPECTED_PRODUCT_TAG_REMOVE_100_LOCATIONS = 266


@override_settings(
Expand Down Expand Up @@ -488,7 +490,10 @@ def test_baseline_zap_scan_reimport_no_change_v3(self):
# Pinned baselines per mode. Each test forces its own V3_FEATURE_LOCATIONS
# via @override_settings so all four import paths run in a single suite
# invocation regardless of the ambient `DD_V3_FEATURE_LOCATIONS` env var.
EXPECTED_ZAP_IMPORT_V2 = 1461
EXPECTED_ZAP_IMPORT_V3 = 1319
EXPECTED_ZAP_REIMPORT_NO_CHANGE_V2 = 77
EXPECTED_ZAP_REIMPORT_NO_CHANGE_V3 = 95
# Phase A nudges these slightly downward (post_save gated on created=True
# avoids re-running inheritance on no-op finding updates during reimport).
# Pre-Phase-A: 1461/1319 import, 77/95 reimport.
EXPECTED_ZAP_IMPORT_V2 = 1385
EXPECTED_ZAP_IMPORT_V3 = 1243
EXPECTED_ZAP_REIMPORT_NO_CHANGE_V2 = 69
EXPECTED_ZAP_REIMPORT_NO_CHANGE_V3 = 87
Loading