Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
122 changes: 78 additions & 44 deletions src/fromager/bootstrap_requirement_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import typing

from packaging.requirements import Requirement
from packaging.utils import NormalizedName, canonicalize_name
from packaging.version import Version

from . import finders, resolver, sources, wheels
Expand Down Expand Up @@ -59,14 +60,16 @@ def __init__(
self.prev_graph = prev_graph
self.multiple_versions = multiple_versions
self.cache_wheel_server_url = cache_wheel_server_url
# Session-level resolution cache to avoid re-resolving same requirements
# Key: (requirement_string, pre_built) to distinguish source vs prebuilt
# Value: tuple of (url, version) tuples sorted by version (highest first)
# Values are stored as immutable tuples to prevent accidental corruption
# when callers modify the returned reference.
self._resolved_requirements: dict[
tuple[str, bool], tuple[tuple[str, Version], ...]
] = {}
# All known versions for a package, accumulated across resolution
# contexts. Versions discovered via different specifiers or req_types
# are merged so that later lookups see the widest set.
# Key: (normalized_name, pre_built)
# Value: {version: url}
self._known_versions: dict[tuple[NormalizedName, bool], dict[Version, str]] = {}
# Requirement rules already resolved from network/graph.
# Key: (str(req), pre_built)
# Prevents redundant network calls for the same specifier.
self._resolved_rules: set[tuple[str, bool]] = set()

def resolve(
self,
Expand All @@ -78,12 +81,16 @@ def resolve(
) -> list[tuple[str, Version]]:
"""Resolve package requirement to matching version(s).

Tries resolution strategies in order:
1. Session cache (if previously resolved)
2. Previous dependency graph
3. PyPI resolution (source or prebuilt based on package build info)
4. Remote wheel cache server (multi-version mode only, when age
filtering produced no candidates)
Uses a two-step strategy:
1. If this requirement rule has not been resolved before, fetch
versions from the network (or previous graph) and extend the
package-level known-versions cache.
2. Filter all known versions for the package by the current
requirement specifier.

This ensures that versions discovered in one context (e.g.,
top-level with cooldown bypass) are visible to later lookups
with different specifiers for the same package.

Args:
req: Package requirement
Expand Down Expand Up @@ -116,12 +123,33 @@ def resolve(
f"Git URL requirements must be handled by Bootstrapper: {req}"
)

# Check session cache (keyed by requirement + pre_built)
cached_result = self.get_cached_resolution(req, pre_built)
if cached_result is not None:
logger.debug(f"resolved {req} from cache")
return list(cached_result) if return_all_versions else [cached_result[0]]
rule_key = (str(req), pre_built)

if rule_key not in self._resolved_rules:
# Rule not seen before — resolve from graph or network and
# extend the package-level known-versions cache.
self._resolve_and_extend(req, req_type, pre_built, parent_req)
self._resolved_rules.add(rule_key)
else:
logger.debug(f"rule already resolved: {req}")

# Filter all known versions by the current requirement specifier.
matching = self.get_matching_versions(req, pre_built)

if not matching:
return []
if return_all_versions:
return matching
return [matching[0]]

def _resolve_and_extend(
self,
req: Requirement,
req_type: RequirementType,
pre_built: bool,
parent_req: Requirement | None,
) -> None:
"""Resolve versions from graph/network and extend known versions cache."""
# Try previous dependency graph
cached_resolution = self._resolve_from_graph(
req=req,
Expand Down Expand Up @@ -178,15 +206,8 @@ def resolve(
"multiple" if self.multiple_versions else "single",
)

# Only cache non-empty results.
if results:
self.cache_resolution(req, pre_built, results)

if not results:
return []
if return_all_versions:
return results
return [results[0]]
self.extend_known_versions(req, pre_built, results)

def _resolve_from_cache_server(self, req: Requirement) -> list[tuple[str, Version]]:
"""Fall back to the remote wheel cache server for a cached version.
Expand Down Expand Up @@ -225,46 +246,59 @@ def _resolve_from_cache_server(self, req: Requirement) -> list[tuple[str, Versio
return [best]
return []

def get_cached_resolution(
def get_matching_versions(
self,
req: Requirement,
pre_built: bool,
) -> tuple[tuple[str, Version], ...] | None:
"""Get a cached resolution result if it exists.
) -> list[tuple[str, Version]]:
"""Filter known versions by requirement specifier.

Returns an immutable tuple to prevent accidental cache corruption.
Returns all known versions of the package that satisfy the
requirement's version specifier, sorted highest-first.

Args:
req: Package requirement to look up in cache
req: Package requirement with version specifier
pre_built: Whether looking for prebuilt or source resolution

Returns:
Tuple of (url, version) tuples if cached, None otherwise
List of (url, version) tuples matching the specifier.
"""
cache_key = (str(req), pre_built)
return self._resolved_requirements.get(cache_key)

def cache_resolution(
key = (canonicalize_name(req.name), pre_built)
versions = self._known_versions.get(key, {})
matching = [
(url, version)
for version, url in versions.items()
if version in req.specifier
]
matching.sort(key=lambda x: x[1], reverse=True)
return matching

def extend_known_versions(
self,
req: Requirement,
pre_built: bool,
result: list[tuple[str, Version]],
) -> None:
"""Cache a resolution result.
"""Extend the known-versions cache and mark the rule as resolved.

The result is stored as an immutable tuple to prevent accidental
corruption when callers modify the original list.
Merges new versions into the package-level cache. When a version
already exists, a non-empty URL takes precedence over an empty one
(graph-resolved placeholders are replaced by real download URLs).

Used by Bootstrapper to cache git URL resolutions that are
handled externally (outside this resolver).

Args:
req: Package requirement to cache
req: Package requirement (used for name and rule tracking)
pre_built: Whether this is a prebuilt or source resolution
result: List of (url, version) tuples
result: List of (url, version) tuples to add
"""
cache_key = (str(req), pre_built)
self._resolved_requirements[cache_key] = tuple(result)
key = (canonicalize_name(req.name), pre_built)
versions = self._known_versions.setdefault(key, {})
for url, version in result:
if version not in versions or (url and not versions[version]):
versions[version] = url
self._resolved_rules.add((str(req), pre_built))

def _resolve_from_graph(
self,
Expand Down
10 changes: 4 additions & 6 deletions src/fromager/bootstrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,19 +287,17 @@ def resolve_versions(

# Check cache first to avoid re-resolving
# Git URLs are always source (not prebuilt)
cached_result = self._resolver.get_cached_resolution(req, pre_built=False)
if cached_result is not None:
cached_result = self._resolver.get_matching_versions(req, pre_built=False)
if cached_result:
logger.debug(f"resolved {req} from cache")
return (
list(cached_result) if return_all_versions else [cached_result[0]]
)
return cached_result if return_all_versions else [cached_result[0]]

logger.info("resolving source via URL, ignoring any plugins")
source_url, resolved_version = self._resolve_version_from_git_url(req=req)
# Cache the git URL resolution (always source, not prebuilt)
# Store as list for consistency with cache structure
result = [(source_url, resolved_version)]
self._resolver.cache_resolution(req, pre_built=False, result=result)
self._resolver.extend_known_versions(req, pre_built=False, result=result)
return result # Git URLs always return single version

# Delegate to RequirementResolver
Expand Down
111 changes: 83 additions & 28 deletions tests/test_bootstrap_requirement_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -498,8 +498,8 @@ def test_resolve_auto_routes_to_prebuilt(
# Mock resolution to return expected result (as list)
mock_resolve.return_value = [
(
"https://files.pythonhosted.org/setuptools-1.0-py3-none-any.whl",
Version("1.0"),
"https://files.pythonhosted.org/setuptools-75.0-py3-none-any.whl",
Version("75.0"),
)
]

Expand All @@ -515,8 +515,8 @@ def test_resolve_auto_routes_to_prebuilt(
mock_resolve.assert_called_once()
assert len(results) == 1
url, version = results[0]
assert url == "https://files.pythonhosted.org/setuptools-1.0-py3-none-any.whl"
assert version == Version("1.0")
assert url == "https://files.pythonhosted.org/setuptools-75.0-py3-none-any.whl"
assert version == Version("75.0")


@patch("fromager.resolver.find_all_matching_from_provider")
Expand Down Expand Up @@ -562,43 +562,52 @@ def test_resolve_auto_routes_to_source(
assert version == Version("2.0")


def test_cache_resolution_stores_immutable_tuple(tmp_context: WorkContext) -> None:
"""cache_resolution() stores an immutable tuple, not the original list."""
def test_extend_known_versions_accumulates(tmp_context: WorkContext) -> None:
"""extend_known_versions() accumulates versions across calls."""
resolver = BootstrapRequirementResolver(tmp_context)
req = Requirement("mypkg>=1.0")
original = [("https://example.com/mypkg-1.0.tar.gz", Version("1.0"))]

resolver.cache_resolution(req, pre_built=False, result=original)
cached = resolver.get_cached_resolution(req, pre_built=False)
req1 = Requirement("mypkg>=1.0")
req2 = Requirement("mypkg>=2.0")

# Cached value should be a tuple
assert isinstance(cached, tuple)
resolver.extend_known_versions(
req1,
pre_built=False,
result=[("https://files.test/mypkg-1.5.tar.gz", Version("1.5"))],
)
resolver.extend_known_versions(
req2,
pre_built=False,
result=[("https://files.test/mypkg-2.0.tar.gz", Version("2.0"))],
)

# Mutating the original list must not affect the cache
original.append(("https://example.com/mypkg-2.0.tar.gz", Version("2.0")))
cached_after = resolver.get_cached_resolution(req, pre_built=False)
assert cached_after is not None
assert len(cached_after) == 1
# Both versions are now available when filtering by the wider specifier
matching = resolver.get_matching_versions(req1, pre_built=False)
assert len(matching) == 2
assert matching[0][1] == Version("2.0")
assert matching[1][1] == Version("1.5")


def test_get_cached_resolution_returns_immutable(tmp_context: WorkContext) -> None:
"""get_cached_resolution() returns a tuple that cannot be mutated."""
def test_get_matching_versions_returns_independent_lists(
tmp_context: WorkContext,
) -> None:
"""get_matching_versions() returns a new list each call."""
resolver = BootstrapRequirementResolver(tmp_context)
req = Requirement("mypkg>=1.0")

resolver.cache_resolution(
resolver.extend_known_versions(
req,
pre_built=False,
result=[("https://example.com/mypkg-1.0.tar.gz", Version("1.0"))],
result=[("https://files.test/mypkg-1.0.tar.gz", Version("1.0"))],
)
cached = resolver.get_cached_resolution(req, pre_built=False)
assert cached is not None
list1 = resolver.get_matching_versions(req, pre_built=False)
list2 = resolver.get_matching_versions(req, pre_built=False)

with pytest.raises(AttributeError):
cached.append(("https://example.com/bad.tar.gz", Version("2.0"))) # type: ignore[attr-defined, union-attr]
assert list1 == list2
assert list1 is not list2

with pytest.raises(TypeError):
cached[0] = ("https://example.com/bad.tar.gz", Version("2.0")) # type: ignore[index]
# Mutating one does not affect the other
list1.append(("https://files.test/bad.tar.gz", Version("9.9")))
list3 = resolver.get_matching_versions(req, pre_built=False)
assert len(list3) == 1


@patch("fromager.resolver.find_all_matching_from_provider")
Expand Down Expand Up @@ -642,6 +651,52 @@ def test_resolve_cache_returns_independent_lists(
mock_resolve.assert_called_once()


@patch("fromager.resolver.find_all_matching_from_provider")
def test_toplevel_version_visible_to_transitive_via_cache(
mock_resolve: MagicMock,
tmp_context: WorkContext,
) -> None:
"""Version found via top-level bypass is visible to later transitive lookups.

Simulates the architect's scenario:
1. Top-level A==2.0.0 resolves with cooldown bypassed → v2.0 cached
2. Transitive A>=1.0 resolves with cooldown enforced → only v1.5 from network
3. But v2.0 is already in the package-level cache, so the transitive
lookup returns v2.0 as the highest matching version.
"""
resolver = BootstrapRequirementResolver(tmp_context)

# Step 1: top-level resolves A==2.0.0 (cooldown bypassed)
mock_resolve.return_value = [
("https://files.test/mypkg-2.0.tar.gz", Version("2.0")),
]
results_toplevel = resolver.resolve(
req=Requirement("mypkg==2.0"),
req_type=RequirementType.TOP_LEVEL,
parent_req=None,
pre_built=False,
)
assert results_toplevel[0][1] == Version("2.0")
assert mock_resolve.call_count == 1

# Step 2: transitive resolves A>=1.0 (cooldown enforced, only finds v1.5)
mock_resolve.return_value = [
("https://files.test/mypkg-1.5.tar.gz", Version("1.5")),
]
results_transitive = resolver.resolve(
req=Requirement("mypkg>=1.0"),
req_type=RequirementType.INSTALL,
parent_req=None,
pre_built=False,
)
assert mock_resolve.call_count == 2

# The transitive lookup sees v2.0 from the package-level cache (found
# during top-level resolution) even though the transitive network call
# only returned v1.5.
assert results_transitive[0][1] == Version("2.0")


@patch("fromager.resolver.find_all_matching_from_provider")
def test_resolve_prebuilt_after_source_uses_separate_cache(
mock_resolve: MagicMock,
Expand Down
Loading