From 141b1e289f0b2d2fee06bc6e71a3c2d84da96fd6 Mon Sep 17 00:00:00 2001 From: Cody Maffucci <46459665+Maffooch@users.noreply.github.com> Date: Fri, 8 May 2026 13:13:45 -0600 Subject: [PATCH 1/2] =?UTF-8?q?:zap:=20speed=20up=20migrate=5Fendpoints=5F?= =?UTF-8?q?to=5Flocations=20(~14=C3=97=20fewer=20queries)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Reduces per-endpoint cost in the Endpoint→Location data migration from ~613 queries / 240 ms to ~44 queries / 17 ms on a 50-endpoint / 1,000-finding local benchmark — a 13.9× query reduction and a 14.1× wall-clock improvement that should bring a 16-hour prod run under one hour. As a side-effect, fixes the latent associate_with_product short-circuit bug where 'Mitigated' could stick on a LocationProductReference even after later Active findings came in for the same product. Changes (kept inside this management command — no edits to shared Location model methods): - select_related/prefetch_related on the main Endpoint queryset so the per-endpoint loop has no hidden joins through tags, endpoint_meta, status_endpoint, or finding→test→engagement→product/mitigated_by. - tags.add(*names) splat instead of N round-trips per tag. - DojoMeta.bulk_create(ignore_conflicts=True) per endpoint instead of get_or_create per row (DojoMeta.unique_together = (location, name) makes ignore_conflicts semantically equivalent here). - LocationFindingReference and LocationProductReference are bulk_created per endpoint instead of going through Location.associate_with_finding / associate_with_product. This bypasses BaseModel.save's full_clean() validate_unique queries AND the inherit_tags_on_linked_instance post_save signal (which fires all_related_products through the finding→test→engagement→product chain on every save). Product status is derived in-memory across all of an endpoint's finding statuses. 
- _suspend_auto_now_add wraps the LocationFindingReference bulk write so the explicit 'created' value (= source Endpoint_Status.date) is honored. Django's SQLInsertCompiler.pre_save_val calls Field.pre_save(add=True) even from bulk_create; auto_now_add would otherwise overwrite our value with now(). - New CLI flags for ops visibility on long runs: --batch-size, --progress-every, --benchmark, --query-count. Default progress line: 'Migrated X/Y (z%) — N ep/sec — ETA …'. Per-step measurements (50 ep / 1,000 findings, V3_FEATURE_LOCATIONS=True, local docker postgres): step wall queries/ep verifier baseline (instrumented) 12.00s 613 14 LPR-status warnings (pre-existing bug) + prefetch_related 10.63s 528 same + tags splat 10.08s 507 same + DojoMeta bulk_create 10.24s 498 same + bulk LFR/LPR + fix 0.85s 44 all strict checks pass Idempotent re-runs validated. Verifier checks counts (URLs, Locations, LFR, LPR, location-DojoMeta), per-row LFR fields (status, created, audit_time, auditor), endpoint→location tag subset, and DojoMeta (location, name) parity. Intentional behavioral diffs vs. the previous code: 1. LocationProductReference.status now reflects 'Active iff any finding for this (location, product) is Active' — fixes the associate_with_product first-write-wins bug. Previously order- dependent; ~28% of product refs were mis-statused on the seeded distribution. 2. Tag inheritance via the inherit_tags_on_linked_instance post_save signal does NOT fire (bulk_create skips signals). For deployments with enable_product_tag_inheritance=True on products (or the system setting on), inherited product tags will not be propagated onto migrated Locations during this command. The seed used in benchmarking does not exercise this path. If your environment uses product tag inheritance, follow up with a one-time Location.inherit_tags pass after this command — or call out and we can bake _bulk_inherit_tags into the migration. 
Co-Authored-By: Claude Opus 4.7 (1M context) --- .../migrate_endpoints_to_locations.py | 369 +++++++++++++++--- 1 file changed, 318 insertions(+), 51 deletions(-) diff --git a/dojo/management/commands/migrate_endpoints_to_locations.py b/dojo/management/commands/migrate_endpoints_to_locations.py index 25c739abb0a..e980a5c92b1 100644 --- a/dojo/management/commands/migrate_endpoints_to_locations.py +++ b/dojo/management/commands/migrate_endpoints_to_locations.py @@ -1,18 +1,54 @@ +import contextlib import datetime import logging +import time from django.core.management.base import BaseCommand +from django.db import connection +from django.db.models import Prefetch from django.utils import timezone -from dojo.location.models import Location -from dojo.location.status import FindingLocationStatus +from dojo.location.models import Location, LocationFindingReference, LocationProductReference +from dojo.location.status import FindingLocationStatus, ProductLocationStatus from dojo.models import DojoMeta, Endpoint, Endpoint_Status from dojo.url.models import URL logger = logging.getLogger(__name__) -# Chunk size for DB cursor and progress report -CHUNK_SIZE = 1000 +# Chunk size for the DB iterator. Tunable via --batch-size. +DEFAULT_BATCH_SIZE = 1000 +# How often to emit per-chunk progress lines. Tunable via --progress-every. +DEFAULT_PROGRESS_EVERY = 50 + + +# `LocationFindingReference.created` is `auto_now_add=True` (inherited from +# BaseModel). The original migration sets `created` to the source +# Endpoint_Status.date in a post-save UPDATE so that auto_now_add is +# bypassed. With bulk_create we don't get a post-save UPDATE; Django's +# SQLInsertCompiler.pre_save_val still calls Field.pre_save(add=True), +# which auto_now_add overrides with `now()`, ignoring our explicit value. +# The cleanest single-process fix is to temporarily flip auto_now_add off +# around the bulk write. 
+@contextlib.contextmanager +def _suspend_auto_now_add(model, field_name: str): + field = model._meta.get_field(field_name) + saved = field.auto_now_add + field.auto_now_add = False + try: + yield + finally: + field.auto_now_add = saved + + +# Phases tracked by --benchmark. Order is preserved in the summary table. +PHASES = ( + "fetch_endpoint", # iterator yields the next endpoint + "url_create", # URL.get_or_create_from_values + Location side-effect + "tags", # endpoint tag copy onto the location + "meta", # DojoMeta copy onto the location + "finding_refs", # LocationFindingReference creation per Endpoint_Status + "product_refs", # LocationProductReference creation +) class Command(BaseCommand): @@ -27,9 +63,47 @@ class Command(BaseCommand): help = "Usage: manage.py migrate_endpoints_to_locations" + def add_arguments(self, parser): + parser.add_argument( + "--batch-size", + type=int, + default=DEFAULT_BATCH_SIZE, + help=f"Endpoint.objects.iterator() chunk size (default: {DEFAULT_BATCH_SIZE}).", + ) + parser.add_argument( + "--progress-every", + type=int, + default=DEFAULT_PROGRESS_EVERY, + help=f"Emit a progress line every N endpoints (default: {DEFAULT_PROGRESS_EVERY}).", + ) + parser.add_argument( + "--benchmark", + action="store_true", + help="Track per-phase wall-clock and print a summary table at the end.", + ) + parser.add_argument( + "--query-count", + action="store_true", + help="Force-debug the DB cursor and count queries per chunk. 
" + "Has measurable overhead; use only for profiling runs.", + ) + + # -- Per-phase timing helpers -------------------------------------------- + + def _bench_start(self) -> float: + return time.perf_counter() if self.benchmark else 0.0 + + def _bench_end(self, phase: str, t0: float) -> None: + if self.benchmark: + self.timings[phase] += time.perf_counter() - t0 + self.counts[phase] += 1 + + # -- Migration logic -------------------------------------------------- + def _endpoint_to_url(self, endpoint: Endpoint) -> Location: # Create the raw URL object first # This should create the location object as well + t = self._bench_start() url = URL.get_or_create_from_values( protocol=endpoint.protocol, user_info=endpoint.userinfo, @@ -39,16 +113,30 @@ def _endpoint_to_url(self, endpoint: Endpoint) -> Location: query=endpoint.query, fragment=endpoint.fragment, ) - # Add the endpoint tags to the location tags - if endpoint.tags: - [url.location.tags.add(tag) for tag in set(endpoint.tags.values_list("name", flat=True))] - # Add any metadata from the endpoint to the location - for meta in endpoint.endpoint_meta.all(): - DojoMeta.objects.get_or_create( - name=meta.name, - value=meta.value, - location=url.location, - ) + self._bench_end("url_create", t) + + # Add the endpoint tags to the location tags. Read names from the + # prefetched `tags` manager and add them in a single splat call so we + # do one round-trip per endpoint instead of one per tag. + t = self._bench_start() + tag_names = {tag.name for tag in endpoint.tags.all()} + if tag_names: + url.location.tags.add(*tag_names) + self._bench_end("tags", t) + + # Add any metadata from the endpoint to the location. + # bulk_create with ignore_conflicts mirrors the previous get_or_create + # semantics — DojoMeta.unique_together = (location, name) so any + # conflict is by definition the same row we'd otherwise have fetched. + # One INSERT per endpoint instead of SELECT+INSERT per meta row. 
+ t = self._bench_start() + meta_rows = [ + DojoMeta(name=m.name, value=m.value, location=url.location) + for m in endpoint.endpoint_meta.all() + ] + if meta_rows: + DojoMeta.objects.bulk_create(meta_rows, ignore_conflicts=True) + self._bench_end("meta", t) return url.location @@ -69,53 +157,232 @@ def _convert_endpoint_status_to_string_status(self, endpoint_status: Endpoint_St return FindingLocationStatus.Active def _associate_location_with_findings(self, endpoint: Endpoint, location: Location) -> None: - # Determine if we can associate from the finding, or if have to use the product (for cases of zero findings on an endpoint) - if endpoint.status_endpoint.exists(): - # Iterate over each endpoint status to get the status and the finding object - for endpoint_status in endpoint.status_endpoint.all(): - if finding := endpoint_status.finding: - # Determine the status of the location based on the status of the endpoint status - status = self._convert_endpoint_status_to_string_status(endpoint_status) - # Create the association (which will also associate with the product) - reference = location.associate_with_finding( - finding=finding, - status=status, - auditor=endpoint_status.mitigated_by, - audit_time=endpoint_status.mitigated_time or endpoint_status.last_modified, - ) - # Update the created date from the endpoint status date - reference.created = timezone.make_aware(datetime.datetime(endpoint_status.date.year, endpoint_status.date.month, endpoint_status.date.day)) - reference.save(update_fields=["created"]) - # If there are no findings, we can at least associate with the product if it exists - elif product := endpoint.product: - location.associate_with_product(product) + # Pull the prefetched list once. Avoids the redundant `.exists()` round- + # trip the prior code did and lets the loop iterate prefetched data. + statuses = list(endpoint.status_endpoint.all()) + + # No findings — associate with the endpoint's product if one exists. 
+ if not statuses: + if endpoint.product_id: + t_p = self._bench_start() + LocationProductReference.objects.bulk_create( + [LocationProductReference( + location=location, + product=endpoint.product, + status=ProductLocationStatus.Mitigated, + relationship="", + relationship_data={}, + )], + ignore_conflicts=True, + ) + self._bench_end("product_refs", t_p) + return + + # Build LFR rows for every status, and build LPR rows deduplicated by + # product, deriving the product status as Active iff any of THIS + # endpoint's findings on that product are Active. This bypasses + # `Location.associate_with_finding` (which would trigger full_clean + # validation + the post_save inherit_tags signal per row) and is + # semantically equivalent to the prior behavior in the common case + # where each endpoint maps to a distinct location. As a side-effect + # it also fixes the existing `associate_with_product` first-write- + # wins bug (where a Mitigated status would stick even when later + # Active findings come in for the same product). + finding_refs: list[LocationFindingReference] = [] + product_status_by_id: dict[int, str] = {} + product_obj_by_id: dict[int, object] = {} + + for endpoint_status in statuses: + finding = endpoint_status.finding + if finding is None: + continue + product = finding.test.engagement.product + status = self._convert_endpoint_status_to_string_status(endpoint_status) + # Endpoint_Status.date is a Date; the original code persisted + # the same midnight-aware datetime in a post-save UPDATE. We + # set it directly here — bulk_create skips auto_now_add so the + # explicit value is honored. 
+ created_dt = timezone.make_aware(datetime.datetime( + endpoint_status.date.year, + endpoint_status.date.month, + endpoint_status.date.day, + )) + finding_refs.append(LocationFindingReference( + location=location, + finding=finding, + status=status, + auditor=endpoint_status.mitigated_by, + audit_time=endpoint_status.mitigated_time or endpoint_status.last_modified, + relationship="", + relationship_data={}, + created=created_dt, + )) + if product.id not in product_obj_by_id: + product_obj_by_id[product.id] = product + product_status_by_id[product.id] = ( + ProductLocationStatus.Active + if status == FindingLocationStatus.Active + else ProductLocationStatus.Mitigated + ) + elif (status == FindingLocationStatus.Active + and product_status_by_id[product.id] != ProductLocationStatus.Active): + product_status_by_id[product.id] = ProductLocationStatus.Active + + t_f = self._bench_start() + if finding_refs: + with _suspend_auto_now_add(LocationFindingReference, "created"): + LocationFindingReference.objects.bulk_create( + finding_refs, ignore_conflicts=True, batch_size=500, + ) + self._bench_end("finding_refs", t_f) + + t_p = self._bench_start() + if product_obj_by_id: + product_refs = [ + LocationProductReference( + location=location, + product=product_obj_by_id[pid], + status=product_status_by_id[pid], + relationship="", + relationship_data={}, + ) + for pid in product_obj_by_id + ] + LocationProductReference.objects.bulk_create( + product_refs, ignore_conflicts=True, batch_size=500, + ) + self._bench_end("product_refs", t_p) + + # -- Progress + summary reporting ---------------------------------------- + + @staticmethod + def _fmt_duration(seconds: float) -> str: + s = int(seconds) + h, rem = divmod(s, 3600) + m, s = divmod(rem, 60) + if h: + return f"{h}h {m}m" + if m: + return f"{m}m {s}s" + return f"{s}s" + + def _log_progress(self, i: int, total: int, run_t0: float, queries_this_chunk: int | None) -> None: + elapsed = time.time() - run_t0 + rate = i / elapsed if 
elapsed > 0 else 0.0 + remaining = (total - i) / rate if rate > 0 else 0.0 + pct = (i / total * 100.0) if total else 100.0 + line = (f"Migrated {i:,}/{total:,} endpoints ({pct:.1f}%) — " + f"{rate:.1f} endpoints/sec — ETA {self._fmt_duration(remaining)}") + if queries_this_chunk is not None: + # Per-endpoint query count for this chunk window only. + chunk_size = self.progress_every + line += f" — {queries_this_chunk / chunk_size:.1f} queries/endpoint" + self.stdout.write(self.style.SUCCESS(line)) + + if self.benchmark: + parts = [f"{p}={self.timings[p]:.1f}s" for p in PHASES] + self.stdout.write(" " + " ".join(parts)) + + def _print_benchmark_summary(self, total_endpoints: int, total_seconds: float) -> None: + if not self.benchmark: + return + total_phase = sum(self.timings.values()) or 1.0 + self.stdout.write(self.style.SUCCESS("=== Benchmark summary ===")) + self.stdout.write(f"{'phase':<16}{'total_s':>10}{'per_endpoint_ms':>18}{'share':>10}") + for phase in PHASES: + t = self.timings[phase] + per = (t * 1000.0 / total_endpoints) if total_endpoints else 0.0 + share = (t / total_phase * 100.0) + self.stdout.write(f"{phase:<16}{t:>10.2f}{per:>18.2f}{share:>9.1f}%") + self.stdout.write(f"{'(wall-clock)':<16}{total_seconds:>10.2f}" + f"{(total_seconds * 1000.0 / total_endpoints if total_endpoints else 0):>18.2f}" + f"{'100.0%':>10}") + + # -- handle --------------------------------------------------------------- def handle(self, *args, **options): + self.benchmark = bool(options.get("benchmark")) + self.query_count = bool(options.get("query_count")) + self.batch_size = int(options["batch_size"]) + self.progress_every = int(options["progress_every"]) + + # Per-phase wall-clock accumulators. 
+ self.timings = dict.fromkeys(PHASES, 0.0) + self.counts = dict.fromkeys(PHASES, 0) + + if self.query_count: + connection.force_debug_cursor = True + queries_at_chunk_start = len(connection.queries) + else: + queries_at_chunk_start = 0 # unused + # Allow endpoints to work with V3/Locations enabled with Endpoint.allow_endpoint_init(): - # Progress counter - i = 0 - # Start off with the endpoint objects - it should contain everything we need - queryset = Endpoint.objects.all() + # Prefetch everything the per-endpoint loop will touch so the + # iterator doesn't trigger N+1 joins: + # - `product` is select_related so we don't lazy-load it for the + # no-findings branch + # - `tags` and `endpoint_meta` are prefetched managers + # - `status_endpoint` is prefetched together with the FK chain + # `finding -> test -> engagement -> product` and `mitigated_by` + # so `associate_with_finding` can read them without queries. + queryset = ( + Endpoint.objects.all() + .select_related("product") + .prefetch_related( + "tags", + "endpoint_meta", + Prefetch( + "status_endpoint", + queryset=Endpoint_Status.objects.select_related( + "finding__test__engagement__product", + "mitigated_by", + ), + ), + ) + ) # Grab the total count so we can communicate progress endpoint_count = queryset.count() + self.stdout.write(self.style.WARNING( + f"Starting migration of {endpoint_count:,} endpoints " + f"(batch={self.batch_size}, progress every {self.progress_every}, " + f"benchmark={'on' if self.benchmark else 'off'}, " + f"query-count={'on' if self.query_count else 'off'})", + )) + run_t0 = time.time() + i = 0 # Process each endpoint - for i, endpoint in enumerate(queryset.iterator(chunk_size=CHUNK_SIZE), 1): - # Progress report every chunk - if not i % CHUNK_SIZE: - self.stdout.write( - self.style.SUCCESS( - f"Migrated {i}/{endpoint_count} endpoints...", - ), - ) + for i, endpoint in enumerate(queryset.iterator(chunk_size=self.batch_size), 1): + t_fetch = self._bench_start() + # iterator already 
produced `endpoint`; bill nothing meaningful + # to fetch_endpoint here — kept as a placeholder that B1's + # prefetch will start incrementing. + self._bench_end("fetch_endpoint", t_fetch) + # Get the URL object first location = self._endpoint_to_url(endpoint) # Associate the URL with the findings associated with the Findings # the association to a finding will also apply to a product automatically self._associate_location_with_findings(endpoint, location) - self.stdout.write( - self.style.SUCCESS( - f"Migrated {i} total endpoints.", - ), - ) + + # Progress report every --progress-every endpoints + if i % self.progress_every == 0: + queries_in_chunk = None + if self.query_count: + queries_in_chunk = len(connection.queries) - queries_at_chunk_start + # Trim the query log so memory doesn't balloon on long runs; + # after clear() the next chunk's baseline is 0. + connection.queries_log.clear() + queries_at_chunk_start = 0 + self._log_progress(i, endpoint_count, run_t0, queries_in_chunk) + + elapsed = time.time() - run_t0 + self.stdout.write(self.style.SUCCESS( + f"Done. Migrated {i:,} endpoints in {self._fmt_duration(elapsed)} " + f"({(i / elapsed if elapsed else 0):.2f} endpoints/sec).", + )) + self._print_benchmark_summary(i, elapsed) + + if self.query_count: + connection.force_debug_cursor = False From e97662d6acf8c3687664c031b76a01d4ab08ae11 Mon Sep 17 00:00:00 2001 From: Cody Maffucci <46459665+Maffooch@users.noreply.github.com> Date: Mon, 11 May 2026 14:23:19 -0600 Subject: [PATCH 2/2] :sparkles: migrate_endpoints_to_locations: tag inheritance + per-endpoint error isolation MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit `bulk_create` (introduced in the prior perf commit) skips the `inherit_tags_on_linked_instance` post_save signal, so deployments with `enable_product_tag_inheritance` enabled (per-product or system-wide) would not pick up inherited product tags on migrated Locations. 
Track (product, location) pairs during the main loop — covering both `endpoint.product` and `finding.test.engagement.product` — and run a post-pass that calls `LocationManager(product)._bulk_inherit_tags(locations)` once per contributing product. The helper rediscovers each location's full product set via LocationProductReference/LocationFindingReference and diff-checks before writing, so revisits of shared locations across product groups are idempotent. ~5 queries per product group vs ~3 per location for a per-location `inherit_tags()` loop. Also wrap the per-endpoint body in a `try`/`except Exception` so a single bad row doesn't abort a multi-hour migration. Failures get logged with full traceback and tracked in `self.failed_endpoints`; the final "Done." line reports `successful/attempted` counts and a yellow warning lists the first 10 failing IDs. `KeyboardInterrupt` / `SystemExit` are not swallowed. The post-pass uses the same pattern per product group. Co-Authored-By: Claude Opus 4.7 (1M context) --- .../migrate_endpoints_to_locations.py | 145 +++++++++++++++++- 1 file changed, 138 insertions(+), 7 deletions(-) diff --git a/dojo/management/commands/migrate_endpoints_to_locations.py b/dojo/management/commands/migrate_endpoints_to_locations.py index e980a5c92b1..081f2077024 100644 --- a/dojo/management/commands/migrate_endpoints_to_locations.py +++ b/dojo/management/commands/migrate_endpoints_to_locations.py @@ -2,6 +2,7 @@ import datetime import logging import time +from collections import defaultdict from django.core.management.base import BaseCommand from django.db import connection @@ -10,7 +11,7 @@ from dojo.location.models import Location, LocationFindingReference, LocationProductReference from dojo.location.status import FindingLocationStatus, ProductLocationStatus -from dojo.models import DojoMeta, Endpoint, Endpoint_Status +from dojo.models import DojoMeta, Endpoint, Endpoint_Status, Product from dojo.url.models import URL logger = logging.getLogger(__name__) @@ -98,6 +99,29 @@
def _bench_end(self, phase: str, t0: float) -> None: self.timings[phase] += time.perf_counter() - t0 self.counts[phase] += 1 + # -- Tag inheritance bookkeeping ----------------------------------------- + + def _track_product_location(self, product: Product, location: Location) -> None: + """ + Record a (product, location) pair for the post-migration tag inheritance pass. + + The migration creates locations that may be linked to multiple products + (via the endpoint's own product and via each finding's product). We + collect every contributing product per location so the post-pass can + call ``LocationManager(product)._bulk_inherit_tags(locations)`` once + per product group — covering the case where a location is shared + across products with differing ``enable_product_tag_inheritance`` + flags (the helper short-circuits via its own diff check on repeat + visits, so redundancy is safe). + """ + if product is None or product.id is None: + return + if location is None or location.id is None: + return + self.locations_by_product_id[product.id].add(location.id) + self.product_obj_by_id.setdefault(product.id, product) + self.location_obj_by_id.setdefault(location.id, location) + # -- Migration logic -------------------------------------------------- def _endpoint_to_url(self, endpoint: Endpoint) -> Location: @@ -197,6 +221,10 @@ def _associate_location_with_findings(self, endpoint: Endpoint, location: Locati if finding is None: continue product = finding.test.engagement.product + # Track this contributing product for the post-migration tag + # inheritance pass (covers the case where a finding's product + # differs from endpoint.product). + self._track_product_location(product, location) status = self._convert_endpoint_status_to_string_status(endpoint_status) # Endpoint_Status.date is a Date; the original code persisted # the same midnight-aware datetime in a post-save UPDATE. 
We @@ -298,6 +326,55 @@ def _print_benchmark_summary(self, total_endpoints: int, total_seconds: float) - f"{(total_seconds * 1000.0 / total_endpoints if total_endpoints else 0):>18.2f}" f"{'100.0%':>10}") + # -- Post-migration tag inheritance -------------------------------------- + + def _run_tag_inheritance(self) -> None: + """ + Drive `LocationManager._bulk_inherit_tags` once per contributing product. + + Each `LocationManager` call is wrapped in its own try/except so a + failure on one product group doesn't prevent the rest from running — + same philosophy as the per-endpoint loop. Tag inheritance is a + purely additive post-pass; the underlying location/reference rows + are already committed by the main loop, so partial failure here + leaves a consistent (if incomplete-inheritance) state that a + targeted re-run can finish. + """ + if not self.locations_by_product_id: + return + + # Lazy import: dojo.importers.* pulls in a lot of modules and we + # don't want it loaded at management-command discovery time. 
+ from dojo.importers.location_manager import LocationManager # noqa: PLC0415 + + t0 = time.time() + n_products = len(self.locations_by_product_id) + n_pairs = sum(len(loc_ids) for loc_ids in self.locations_by_product_id.values()) + n_unique_locations = len(self.location_obj_by_id) + n_failures = 0 + for prod_id, loc_ids in self.locations_by_product_id.items(): + product = self.product_obj_by_id[prod_id] + locations = [self.location_obj_by_id[lid] for lid in loc_ids] + try: + LocationManager(product)._bulk_inherit_tags(locations) + except Exception: + logger.exception( + "Tag inheritance pass failed for product id=%s " + "(%d location(s)); continuing with remaining products", + prod_id, len(locations), + ) + n_failures += 1 + elapsed = time.time() - t0 + msg = ( + f"Tag inheritance pass: visited {n_pairs:,} (product, location) pair(s) " + f"across {n_products:,} product(s), {n_unique_locations:,} unique location(s), " + f"in {elapsed:.2f}s" + ) + if n_failures: + self.stdout.write(self.style.WARNING(f"{msg} — {n_failures} product group(s) failed")) + else: + self.stdout.write(self.style.SUCCESS(msg)) + # -- handle --------------------------------------------------------------- def handle(self, *args, **options): @@ -310,6 +387,21 @@ def handle(self, *args, **options): self.timings = dict.fromkeys(PHASES, 0.0) self.counts = dict.fromkeys(PHASES, 0) + # Bookkeeping for the post-migration tag inheritance pass. + # `locations_by_product_id` maps product.id -> set of location.ids + # contributed by that product (via endpoint.product OR finding.test. + # engagement.product). We hold the Product/Location objects in + # parallel maps so the post-pass can hand them directly to + # `LocationManager(product)._bulk_inherit_tags(locations)` without + # extra DB lookups. 
+ self.locations_by_product_id: dict[int, set[int]] = defaultdict(set) + self.product_obj_by_id: dict[int, Product] = {} + self.location_obj_by_id: dict[int, Location] = {} + + # Collected per-endpoint failures so a single bad row doesn't abort + # a multi-hour migration. Each entry is (endpoint_id, exception_str). + self.failed_endpoints: list[tuple[int | None, str]] = [] + if self.query_count: connection.force_debug_cursor = True queries_at_chunk_start = len(connection.queries) @@ -360,11 +452,29 @@ def handle(self, *args, **options): # prefetch will start incrementing. self._bench_end("fetch_endpoint", t_fetch) - # Get the URL object first - location = self._endpoint_to_url(endpoint) - # Associate the URL with the findings associated with the Findings - # the association to a finding will also apply to a product automatically - self._associate_location_with_findings(endpoint, location) + # Wrap the per-endpoint work so one failure doesn't abort a + # multi-hour migration. We log the full traceback and record + # the endpoint id, then continue. The bulk_create-based hot + # path makes partial-state on failure unlikely (each phase + # is its own bulk insert), and any rows that DID land remain + # valid and idempotent on re-run. + try: + # Get the URL object first + location = self._endpoint_to_url(endpoint) + # Track the endpoint's own product as a contributor for the + # post-migration tag inheritance pass (the no-findings + # branch of _associate_location_with_findings also depends + # on this product, and it won't be tracked otherwise). 
+ if endpoint.product_id: + self._track_product_location(endpoint.product, location) + # Associate the URL with the findings associated with the Findings + # the association to a finding will also apply to a product automatically + self._associate_location_with_findings(endpoint, location) + except Exception as exc: + endpoint_id = getattr(endpoint, "id", None) + logger.exception("Failed to migrate endpoint id=%s; continuing", endpoint_id) + self.failed_endpoints.append((endpoint_id, str(exc))) + continue # Progress report every --progress-every endpoints if i % self.progress_every == 0: @@ -378,10 +488,31 @@ def handle(self, *args, **options): self._log_progress(i, endpoint_count, run_t0, queries_in_chunk) elapsed = time.time() - run_t0 + successful = i - len(self.failed_endpoints) self.stdout.write(self.style.SUCCESS( - f"Done. Migrated {i:,} endpoints in {self._fmt_duration(elapsed)} " + f"Done. Migrated {successful:,}/{i:,} endpoints in {self._fmt_duration(elapsed)} " f"({(i / elapsed if elapsed else 0):.2f} endpoints/sec).", )) + if self.failed_endpoints: + preview_ids = [eid for eid, _ in self.failed_endpoints[:10]] + self.stdout.write(self.style.WARNING( + f"{len(self.failed_endpoints):,} endpoint(s) failed; see logger output above " + f"for tracebacks. First failing endpoint IDs: {preview_ids}", + )) + + # Run the post-migration tag inheritance pass. `bulk_create` skips + # the `inherit_tags_on_linked_instance` post_save signal, so for + # deployments with `enable_product_tag_inheritance` enabled (per + # product or system-wide) the migrated Locations would otherwise + # not pick up inherited product tags. We grouped (product, + # location) pairs during the main loop and now drive + # `LocationManager._bulk_inherit_tags` once per contributing + # product. 
The helper rediscovers each location's full product + # set via LocationProductReference/LocationFindingReference and + # diff-checks before writing, so revisits of shared locations + # across product groups are idempotent. + self._run_tag_inheritance() + self._print_benchmark_summary(i, elapsed) if self.query_count: