diff --git a/src/fromager/bootstrap_requirement_resolver.py b/src/fromager/bootstrap_requirement_resolver.py index e62e49a0..3cc2abdb 100644 --- a/src/fromager/bootstrap_requirement_resolver.py +++ b/src/fromager/bootstrap_requirement_resolver.py @@ -7,9 +7,11 @@ from __future__ import annotations import logging +import threading import typing from packaging.requirements import Requirement +from packaging.utils import NormalizedName, canonicalize_name from packaging.version import Version from . import finders, resolver, sources, wheels @@ -59,14 +61,18 @@ def __init__( self.prev_graph = prev_graph self.multiple_versions = multiple_versions self.cache_wheel_server_url = cache_wheel_server_url - # Session-level resolution cache to avoid re-resolving same requirements - # Key: (requirement_string, pre_built) to distinguish source vs prebuilt - # Value: tuple of (url, version) tuples sorted by version (highest first) - # Values are stored as immutable tuples to prevent accidental corruption - # when callers modify the returned reference. - self._resolved_requirements: dict[ - tuple[str, bool], tuple[tuple[str, Version], ...] - ] = {} + # All known versions for a package, accumulated across resolution + # contexts. Versions discovered via different specifiers or req_types + # are merged so that later lookups see the widest set. + # Key: (normalized_name, pre_built) + # Value: {version: url} + self._known_versions: dict[tuple[NormalizedName, bool], dict[Version, str]] = {} + # Requirement rules already resolved from network/graph. + # Key: (str(req), pre_built) + # Prevents redundant network calls for the same specifier. + self._resolved_rules: set[tuple[str, bool]] = set() + # Protects _known_versions and _resolved_rules for thread safety. + self._lock = threading.Lock() def resolve( self, @@ -78,12 +84,16 @@ def resolve( ) -> list[tuple[str, Version]]: """Resolve package requirement to matching version(s). - Tries resolution strategies in order: - 1. Session cache (if previously resolved) - 2. Previous dependency graph - 3. PyPI resolution (source or prebuilt based on package build info) - 4. Remote wheel cache server (multi-version mode only, when age - filtering produced no candidates) + Uses a two-step strategy: + 1. If this requirement rule has not been resolved before, fetch + versions from the network (or previous graph) and extend the + package-level known-versions cache. + 2. Filter all known versions for the package by the current + requirement specifier. + + This ensures that versions discovered in one context (e.g., + top-level with cooldown bypass) are visible to later lookups + with different specifiers for the same package. Args: req: Package requirement @@ -116,12 +126,34 @@ def resolve( f"Git URL requirements must be handled by Bootstrapper: {req}" ) - # Check session cache (keyed by requirement + pre_built) - cached_result = self.get_cached_resolution(req, pre_built) - if cached_result is not None: - logger.debug(f"resolved {req} from cache") - return list(cached_result) if return_all_versions else [cached_result[0]] + rule_key = (str(req), pre_built) + + with self._lock: + if rule_key not in self._resolved_rules: + # Rule not seen before — resolve from graph or network and + # extend the package-level known-versions cache. + self._resolve_and_extend(req, req_type, pre_built, parent_req) + self._resolved_rules.add(rule_key) + else: + logger.debug(f"rule already resolved: {req}") + # Filter all known versions by the current requirement specifier. + matching = self._get_matching_versions(req, pre_built) + + if not matching: + return [] + if return_all_versions: + return matching + return [matching[0]] + + def _resolve_and_extend( + self, + req: Requirement, + req_type: RequirementType, + pre_built: bool, + parent_req: Requirement | None, + ) -> None: + """Resolve versions from graph/network and extend known versions cache.""" # Try previous dependency graph cached_resolution = self._resolve_from_graph( req=req, @@ -178,15 +210,8 @@ def resolve( "multiple" if self.multiple_versions else "single", ) - # Only cache non-empty results. if results: - self.cache_resolution(req, pre_built, results) - - if not results: - return [] - if return_all_versions: - return results - return [results[0]] + self._extend_known_versions(req, pre_built, results) def _resolve_from_cache_server(self, req: Requirement) -> list[tuple[str, Version]]: """Fall back to the remote wheel cache server for a cached version. @@ -225,46 +250,78 @@ def _resolve_from_cache_server(self, req: Requirement) -> list[tuple[str, Versio return [best] return [] - def get_cached_resolution( + def get_matching_versions( self, req: Requirement, pre_built: bool, - ) -> tuple[tuple[str, Version], ...] | None: - """Get a cached resolution result if it exists. + ) -> list[tuple[str, Version]]: + """Filter known versions by requirement specifier (thread-safe). - Returns an immutable tuple to prevent accidental cache corruption. + Returns all known versions of the package that satisfy the + requirement's version specifier, sorted highest-first. Args: - req: Package requirement to look up in cache + req: Package requirement with version specifier pre_built: Whether looking for prebuilt or source resolution Returns: - Tuple of (url, version) tuples if cached, None otherwise + List of (url, version) tuples matching the specifier. """ - cache_key = (str(req), pre_built) - return self._resolved_requirements.get(cache_key) + with self._lock: + return self._get_matching_versions(req, pre_built) - def cache_resolution( + def _get_matching_versions( + self, + req: Requirement, + pre_built: bool, + ) -> list[tuple[str, Version]]: + """Filter known versions (caller must hold ``self._lock``).""" + key = (canonicalize_name(req.name), pre_built) + versions = self._known_versions.get(key, {}) + matching = [ + (url, version) + for version, url in versions.items() + if version in req.specifier + ] + matching.sort(key=lambda x: x[1], reverse=True) + return matching + + def extend_known_versions( self, req: Requirement, pre_built: bool, result: list[tuple[str, Version]], ) -> None: - """Cache a resolution result. + """Extend the known-versions cache and mark the rule as resolved (thread-safe). - The result is stored as an immutable tuple to prevent accidental - corruption when callers modify the original list. + Merges new versions into the package-level cache. When a version + already exists, a non-empty URL takes precedence over an empty one + (graph-resolved placeholders are replaced by real download URLs). Used by Bootstrapper to cache git URL resolutions that are handled externally (outside this resolver). Args: - req: Package requirement to cache + req: Package requirement (used for name and rule tracking) pre_built: Whether this is a prebuilt or source resolution - result: List of (url, version) tuples + result: List of (url, version) tuples to add """ - cache_key = (str(req), pre_built) - self._resolved_requirements[cache_key] = tuple(result) + with self._lock: + self._extend_known_versions(req, pre_built, result) + + def _extend_known_versions( + self, + req: Requirement, + pre_built: bool, + result: list[tuple[str, Version]], + ) -> None: + """Extend known versions (caller must hold ``self._lock``).""" + key = (canonicalize_name(req.name), pre_built) + versions = self._known_versions.setdefault(key, {}) + for url, version in result: + if version not in versions or (url and not versions[version]): + versions[version] = url + self._resolved_rules.add((str(req), pre_built)) def _resolve_from_graph( self, diff --git a/src/fromager/bootstrapper.py b/src/fromager/bootstrapper.py index 10514cd0..204ad903 100644 --- a/src/fromager/bootstrapper.py +++ b/src/fromager/bootstrapper.py @@ -287,19 +287,17 @@ def resolve_versions( # Check cache first to avoid re-resolving # Git URLs are always source (not prebuilt) - cached_result = self._resolver.get_cached_resolution(req, pre_built=False) - if cached_result is not None: + cached_result = self._resolver.get_matching_versions(req, pre_built=False) + if cached_result: logger.debug(f"resolved {req} from cache") - return ( - list(cached_result) if return_all_versions else [cached_result[0]] - ) + return cached_result if return_all_versions else [cached_result[0]] logger.info("resolving source via URL, ignoring any plugins") source_url, resolved_version = self._resolve_version_from_git_url(req=req) # Cache the git URL resolution (always source, not prebuilt) # Store as list for consistency with cache structure result = [(source_url, resolved_version)] - self._resolver.cache_resolution(req, pre_built=False, result=result) + self._resolver.extend_known_versions(req, pre_built=False, result=result) return result # Git URLs always return single version # Delegate to RequirementResolver diff --git a/tests/test_bootstrap_requirement_resolver.py b/tests/test_bootstrap_requirement_resolver.py index 7a4ca3ea..4c9148c5 100644 --- a/tests/test_bootstrap_requirement_resolver.py +++ b/tests/test_bootstrap_requirement_resolver.py @@ -498,8 +498,8 @@ def test_resolve_auto_routes_to_prebuilt( # Mock resolution to return expected result (as list) mock_resolve.return_value = [ ( - "https://files.pythonhosted.org/setuptools-1.0-py3-none-any.whl", - Version("1.0"), + "https://files.pythonhosted.org/setuptools-75.0-py3-none-any.whl", + Version("75.0"), ) ] @@ -515,8 +515,8 @@ def test_resolve_auto_routes_to_prebuilt( mock_resolve.assert_called_once() assert len(results) == 1 url, version = results[0] - assert url == "https://files.pythonhosted.org/setuptools-1.0-py3-none-any.whl" - assert version == Version("1.0") + assert url == "https://files.pythonhosted.org/setuptools-75.0-py3-none-any.whl" + assert version == Version("75.0") @patch("fromager.resolver.find_all_matching_from_provider") @@ -562,43 +562,52 @@ def test_resolve_auto_routes_to_source( assert version == Version("2.0") -def test_cache_resolution_stores_immutable_tuple(tmp_context: WorkContext) -> None: - """cache_resolution() stores an immutable tuple, not the original list.""" +def test_extend_known_versions_accumulates(tmp_context: WorkContext) -> None: + """extend_known_versions() accumulates versions across calls.""" resolver = BootstrapRequirementResolver(tmp_context) - req = Requirement("mypkg>=1.0") - original = [("https://example.com/mypkg-1.0.tar.gz", Version("1.0"))] - - resolver.cache_resolution(req, pre_built=False, result=original) - cached = resolver.get_cached_resolution(req, pre_built=False) + req1 = Requirement("mypkg>=1.0") + req2 = Requirement("mypkg>=2.0") - # Cached value should be a tuple - assert isinstance(cached, tuple) + resolver.extend_known_versions( + req1, + pre_built=False, + result=[("https://files.test/mypkg-1.5.tar.gz", Version("1.5"))], + ) + resolver.extend_known_versions( + req2, + pre_built=False, + result=[("https://files.test/mypkg-2.0.tar.gz", Version("2.0"))], + ) - # Mutating the original list must not affect the cache - original.append(("https://example.com/mypkg-2.0.tar.gz", Version("2.0"))) - cached_after = resolver.get_cached_resolution(req, pre_built=False) - assert cached_after is not None - assert len(cached_after) == 1 + # Both versions are now available when filtering by the wider specifier + matching = resolver.get_matching_versions(req1, pre_built=False) + assert len(matching) == 2 + assert matching[0][1] == Version("2.0") + assert matching[1][1] == Version("1.5") -def test_get_cached_resolution_returns_immutable(tmp_context: WorkContext) -> None: - """get_cached_resolution() returns a tuple that cannot be mutated.""" +def test_get_matching_versions_returns_independent_lists( + tmp_context: WorkContext, +) -> None: + """get_matching_versions() returns a new list each call.""" resolver = BootstrapRequirementResolver(tmp_context) req = Requirement("mypkg>=1.0") - resolver.cache_resolution( + resolver.extend_known_versions( req, pre_built=False, - result=[("https://example.com/mypkg-1.0.tar.gz", Version("1.0"))], + result=[("https://files.test/mypkg-1.0.tar.gz", Version("1.0"))], ) - cached = resolver.get_cached_resolution(req, pre_built=False) - assert cached is not None + list1 = resolver.get_matching_versions(req, pre_built=False) + list2 = resolver.get_matching_versions(req, pre_built=False) - with pytest.raises(AttributeError): - cached.append(("https://example.com/bad.tar.gz", Version("2.0"))) # type: ignore[attr-defined, union-attr] + assert list1 == list2 + assert list1 is not list2 - with pytest.raises(TypeError): - cached[0] = ("https://example.com/bad.tar.gz", Version("2.0")) # type: ignore[index] + # Mutating one does not affect the other + list1.append(("https://files.test/bad.tar.gz", Version("9.9"))) + list3 = resolver.get_matching_versions(req, pre_built=False) + assert len(list3) == 1 @patch("fromager.resolver.find_all_matching_from_provider") @@ -642,6 +651,52 @@ def test_resolve_cache_returns_independent_lists( mock_resolve.assert_called_once() +@patch("fromager.resolver.find_all_matching_from_provider") +def test_toplevel_version_visible_to_transitive_via_cache( + mock_resolve: MagicMock, + tmp_context: WorkContext, +) -> None: + """Version found via top-level bypass is visible to later transitive lookups. + + Simulates the architect's scenario: + 1. Top-level A==2.0.0 resolves with cooldown bypassed → v2.0 cached + 2. Transitive A>=1.0 resolves with cooldown enforced → only v1.5 from network + 3. But v2.0 is already in the package-level cache, so the transitive + lookup returns v2.0 as the highest matching version. + """ + resolver = BootstrapRequirementResolver(tmp_context) + + # Step 1: top-level resolves A==2.0.0 (cooldown bypassed) + mock_resolve.return_value = [ + ("https://files.test/mypkg-2.0.tar.gz", Version("2.0")), + ] + results_toplevel = resolver.resolve( + req=Requirement("mypkg==2.0"), + req_type=RequirementType.TOP_LEVEL, + parent_req=None, + pre_built=False, + ) + assert results_toplevel[0][1] == Version("2.0") + assert mock_resolve.call_count == 1 + + # Step 2: transitive resolves A>=1.0 (cooldown enforced, only finds v1.5) + mock_resolve.return_value = [ + ("https://files.test/mypkg-1.5.tar.gz", Version("1.5")), + ] + results_transitive = resolver.resolve( + req=Requirement("mypkg>=1.0"), + req_type=RequirementType.INSTALL, + parent_req=None, + pre_built=False, + ) + assert mock_resolve.call_count == 2 + + # The transitive lookup sees v2.0 from the package-level cache (found + # during top-level resolution) even though the transitive network call + # only returned v1.5. + assert results_transitive[0][1] == Version("2.0") + + @patch("fromager.resolver.find_all_matching_from_provider") def test_resolve_prebuilt_after_source_uses_separate_cache( mock_resolve: MagicMock,