Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
68 changes: 65 additions & 3 deletions xrspatial/proximity.py
Original file line number Diff line number Diff line change
Expand Up @@ -735,6 +735,38 @@ def _process_proximity_line(
return


def _kdtree_query_lowest_index(tree, query_pts, p, max_distance):
    """Nearest-target query that breaks ties by lowest target index.

    ``cKDTree.query`` does not promise which of several equidistant targets
    it returns, so allocation and direction can disagree with the brute-force
    and CUDA backends on a tie. Target coordinates are stored in row-major
    (flat-index) order, so the lowest target index is the lowest flat index --
    the tie-break policy documented on ``allocation``/``direction``.

    The two nearest targets are queried first; rows whose two neighbours are
    not equidistant are already final. For the tied rows only, the query is
    repeated with a doubling ``k`` until the farthest returned neighbour is
    strictly farther than the nearest one (or ``k`` reaches the number of
    targets). At that point every member of the tie group is visible and the
    smallest index among them is kept, so ties between any number of targets
    resolve independently of cKDTree's internal traversal order.

    Parameters
    ----------
    tree : scipy.spatial.cKDTree
        Tree built over the target coordinates, in row-major target order.
    query_pts : array_like, shape (n, m)
        Points to look up.
    p : float
        Minkowski norm passed through to ``tree.query``.
    max_distance : float
        Passed through as ``distance_upper_bound``.

    Returns
    -------
    dists : ndarray, shape (n,)
        Distance to the nearest target (``inf`` when none is within
        ``max_distance``).
    indices : ndarray, shape (n,)
        Index of the chosen target (``tree.n`` when none is within
        ``max_distance``).
    """
    n_targets = tree.n
    if n_targets < 2:
        # With fewer than two targets a tie is impossible.
        return tree.query(query_pts, p=p, distance_upper_bound=max_distance)

    # Fancy indexing is needed below to re-query only the tied rows.
    query_pts = np.asarray(query_pts)

    dists2, idx2 = tree.query(query_pts, k=2, p=p,
                              distance_upper_bound=max_distance)
    dists = dists2[:, 0]
    indices = idx2[:, 0]
    # A tie exists where both neighbours are finite and equidistant.
    tied = np.isfinite(dists2[:, 1]) & (dists2[:, 1] == dists)
    if not tied.any():
        return dists, indices

    # Resolve what the k=2 result can see; np.where yields a fresh, writable
    # array that the widening loop below updates in place.
    indices = np.where(tied, np.minimum(idx2[:, 0], idx2[:, 1]), indices)

    # The two neighbours returned for a tied row may be an arbitrary pair out
    # of a larger tie group. Widen k for those rows until the last neighbour
    # is no longer part of the tie. Only tied rows are re-queried, so the
    # common no-tie / 2-way-tie cost is unchanged.
    pending = np.flatnonzero(tied)
    k = 2
    while pending.size and k < n_targets:
        k = min(2 * k, n_targets)
        d_k, i_k = tree.query(query_pts[pending], k=k, p=p,
                              distance_upper_bound=max_distance)
        # Neighbours are returned sorted by distance, so the tie group is
        # the leading run equal to column 0. Missing neighbours come back as
        # inf and never compare equal to the (finite) nearest distance.
        same = d_k == d_k[:, :1]
        # Mask non-tied columns with n_targets (larger than any real index)
        # before taking the row-wise minimum.
        indices[pending] = np.where(same, i_k, n_targets).min(axis=1)
        # Rows whose last neighbour is still tied may have more members.
        pending = pending[same[:, -1]]
    return dists, indices


def _kdtree_chunk_fn(block, y_coords_1d, x_coords_1d,
tree, block_info, max_distance, p,
process_mode, target_vals, target_coords):
Expand All @@ -751,8 +783,8 @@ def _kdtree_chunk_fn(block, y_coords_1d, x_coords_1d,
yy, xx = np.meshgrid(chunk_ys, chunk_xs, indexing='ij')
query_pts = np.column_stack([yy.ravel(), xx.ravel()])

dists, indices = tree.query(query_pts, p=p,
distance_upper_bound=max_distance)
dists, indices = _kdtree_query_lowest_index(
tree, query_pts, p, max_distance)

n_targets = len(target_vals)
oob = indices >= n_targets
Expand Down Expand Up @@ -979,7 +1011,7 @@ def _tiled_chunk_query(raster, iy, ix, y_coords, x_coords,

tree = cKDTree(target_coords)
ub = max_distance if np.isfinite(max_distance) else np.inf
dists, indices = tree.query(query_pts, p=p, distance_upper_bound=ub)
dists, indices = _kdtree_query_lowest_index(tree, query_pts, p, ub)

n_targets = len(target_vals)
oob = indices >= n_targets
Expand Down Expand Up @@ -1374,6 +1406,22 @@ def _process_numpy(img, x_coords, y_coords):
img, x_coords, y_coords, target_values,
np.float32(max_distance), distance_metric, process_mode,
)
# ALLOCATION and DIRECTION pick a single target per pixel, so a tie
# between two equidistant targets must resolve the same way on every
# backend. The line-sweep breaks ties as a side effect of its
# four-pass propagation order, which disagrees with the brute-force
# and CUDA kernels. Route those modes through the brute-force search,
# which keeps the lowest-flat-index target on a tie (see the
# `allocation`/`direction` docstrings). Brute force is O(N*T) versus
# the line-sweep's O(N); the slower scan is a deliberate trade for a
# tie-break that matches every other backend. PROXIMITY only returns
# the distance, which is identical for tied targets, so the faster
# line-sweep stays in use there.
if process_mode in (ALLOCATION, DIRECTION):
return _process_numpy_bruteforce(
img, x_coords, y_coords, target_values,
np.float32(max_distance), distance_metric, process_mode,
)
return _process_numpy_linesweep(img, x_coords, y_coords)

def _process_dask(raster, xs, ys):
Expand Down Expand Up @@ -1680,6 +1728,13 @@ def allocation(
`allocation` chunk by chunk by expanding the chunk's borders to cover
the `max_distance`.

Tie-breaking: when two or more targets are exactly equidistant from a
pixel, the target with the lowest flat (row-major) index wins, i.e. the
first target encountered when scanning the raster top-to-bottom and
left-to-right. This policy is identical across all backends (numpy, cupy,
dask+numpy, dask+cupy), so the allocated value is deterministic regardless
of which backend computes it.

Parameters
----------
raster : xr.DataArray or xr.Dataset
Expand Down Expand Up @@ -1831,6 +1886,13 @@ def direction(
proximity direction chunk by chunk by expanding the chunk's borders
to cover the `max_distance`.

Tie-breaking: when two or more targets are exactly equidistant from a
pixel, the direction is computed toward the target with the lowest flat
(row-major) index, i.e. the first target encountered when scanning the
raster top-to-bottom and left-to-right. This policy is identical across
all backends (numpy, cupy, dask+numpy, dask+cupy), so the reported
direction is deterministic regardless of which backend computes it.

Parameters
----------
raster : xr.DataArray or xr.Dataset
Expand Down
86 changes: 86 additions & 0 deletions xrspatial/tests/test_proximity.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,6 +535,58 @@ def test_proximity_dask_kdtree_with_target_values():
)


# ---------------------------------------------------------------------------
# Tie-breaking: when two targets are equidistant, every backend must pick the
# same one. The documented policy is "lowest flat (row-major) index wins".
# ---------------------------------------------------------------------------

@pytest.fixture
def tie_break_raster_data():
    """3x3 float64 grid with two targets that tie along the centre column.

    Target value 1 sits at (1, 0) (flat index 3) and target value 2 at
    (1, 2) (flat index 5). Every pixel in the centre column is equidistant
    to both, so the lowest-flat-index policy allocates that column to 1.
    """
    grid = np.zeros((3, 3), dtype=np.float64)
    grid[1, 0] = 1.
    grid[1, 2] = 2.
    return grid


@pytest.fixture
def tie_break_expected_allocation():
    """Expected float32 allocation: every row is ``[1, 1, 2]``.

    The tied centre column goes to target 1.
    """
    row = [1., 1., 2.]
    return np.array([row, row, row], dtype=np.float32)


@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy', 'cupy', 'dask+cupy'])
def test_allocation_tie_break_lowest_flat_index(
        backend, tie_break_raster_data, tie_break_expected_allocation):
    """Allocation on a tie picks the lowest-flat-index target on each backend."""
    # One-pixel chunks are requested for the raster under test.
    input_raster = create_test_raster(
        tie_break_raster_data,
        backend=backend,
        dims=['lat', 'lon'],
        chunks=(1, 1),
    )
    allocated = allocation(input_raster, x='lon', y='lat')
    general_output_checks(
        input_raster, allocated, tie_break_expected_allocation,
        verify_dtype=True,
    )


@pytest.mark.parametrize("backend", ['numpy', 'dask+numpy', 'cupy', 'dask+cupy'])
def test_direction_tie_break_matches_numpy(backend, tie_break_raster_data):
    """Direction on a tie is identical to the numpy backend's answer."""
    # numpy is the reference implementation; compute its result first and
    # pin the backend under test to it.
    reference_raster = create_test_raster(
        tie_break_raster_data, backend='numpy', dims=['lat', 'lon'],
    )
    reference_direction = direction(reference_raster, x='lon', y='lat').data

    backend_raster = create_test_raster(
        tie_break_raster_data,
        backend=backend,
        dims=['lat', 'lon'],
        chunks=(1, 1),
    )
    backend_direction = direction(backend_raster, x='lon', y='lat')
    general_output_checks(
        backend_raster, backend_direction, reference_direction)


@pytest.mark.skipif(da is None, reason="dask is not installed")
def test_proximity_dask_kdtree_no_targets():
"""No target pixels found → result is all NaN."""
Expand Down Expand Up @@ -704,6 +756,40 @@ def test_proximity_dask_kdtree_tiled_manhattan():
)


@pytest.mark.skipif(da is None, reason="dask is not installed")
def test_allocation_tie_break_tiled_path():
    """The eager tiled KDTree fallback obeys the lowest-flat-index tie-break."""
    grid = np.array([[0., 0., 0.],
                     [1., 0., 2.],
                     [0., 0., 0.]], dtype=np.float64)
    raster = xr.DataArray(grid, dims=['lat', 'lon'])
    raster['lon'] = np.array([0., 1., 2.])
    raster['lat'] = np.array([2., 1., 0.])
    raster.data = da.from_array(grid, chunks=(1, 1))

    # Force the eager tiled KDTree path (see _force_tiled_proximity for the
    # counter semantics): the first two memory probes report a tiny budget,
    # every later probe reports a large one so the result-size guard passes.
    probes_seen = 0

    def _fake_available_memory():
        nonlocal probes_seen
        probes_seen += 1
        return 1 if probes_seen <= 2 else 10 * 1024 ** 3

    with patch('xrspatial.proximity._available_memory_bytes',
               side_effect=_fake_available_memory):
        allocated = allocation(raster, x='lon', y='lat')

    # The tied centre column must go to target 1 (lower flat index).
    row = [1., 1., 2.]
    np.testing.assert_array_equal(
        allocated.values, np.array([row, row, row], dtype=np.float32))


@pytest.mark.skipif(da is None, reason="dask is not installed")
def test_proximity_dask_kdtree_tiled_single_target():
"""One target in a corner, many chunks → exercises max ring expansion."""
Expand Down
Loading