Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 44 additions & 1 deletion xrspatial/proximity.py
Original file line number Diff line number Diff line change
Expand Up @@ -304,6 +304,43 @@ def _min_step_distance(coords, x_ref, y_ref, along):
return pad_y, pad_x


def _fit_halo_to_chunks(pad_y, pad_x, *arrays):
    """Reconcile a bounded halo with the chunk layout used by ``map_overlap``.

    The pixel halo computed for a bounded search can be deeper than the
    smallest chunk on an axis -- for instance a 3-row raster paired with a
    10-pixel halo, or a great-circle raster whose pixel pitch shrinks toward
    the poles -- and ``da.map_overlap`` will not take such a depth. For each
    axis where that happens, the axis is collapsed into one chunk spanning
    its full length and its halo depth is set to zero. Every block then
    holds the complete axis, so no target inside ``max_distance`` can be
    missed and the output still agrees with the NumPy backend.

    Chunking on the affected axis is given up on purpose. Merely clamping
    the depth while leaving several chunks in place would silently lose
    targets located in a non-adjacent chunk, so do not swap the fold for a
    plain clamp.

    Parameters
    ----------
    pad_y, pad_x : int
        Requested halo depth along axis 0 and axis 1.
    *arrays
        The arrays handed to the same ``map_overlap`` call (the raster and
        the coordinate grids). The first one supplies ``shape`` and
        ``chunks``; all of them are rechunked with the same spec so the
        call stays aligned.

    Returns
    -------
    tuple
        ``(pad_y, pad_x, arrays)`` where the depths are the adjusted values
        and ``arrays`` is a list, in the order given, of the (possibly
        rechunked) inputs.
    """
    grids = list(arrays)
    lead = grids[0]
    n_rows, n_cols = lead.shape
    row_chunks, col_chunks = lead.chunks

    depths = [pad_y, pad_x]
    layout = ((row_chunks, n_rows), (col_chunks, n_cols))

    # Maps axis index -> full axis length for every axis that must be folded.
    # NOTE(review): the trigger is the *smallest* chunk, so a single short
    # trailing chunk is enough to fold the whole axis -- confirm that is
    # acceptable for large rasters.
    folded = {}
    for axis, (sizes, length) in enumerate(layout):
        if depths[axis] > min(sizes):
            folded[axis] = length
            depths[axis] = 0

    if folded:
        grids = [grid.rechunk(folded) for grid in grids]

    return depths[0], depths[1], grids


@ngjit
def _calc_direction(x1, x2, y1, y2):
# Calculate direction from (x1, y1) to a source cell (x2, y2).
Expand Down Expand Up @@ -574,6 +611,10 @@ def _process_dask_cupy(raster, x_coords, y_coords, target_values,
ys = da.repeat(y_da, raster.shape[1]).reshape(
raster.shape).rechunk(raster.data.chunks)

# Keep the overlap depth within what map_overlap accepts on skinny rasters.
pad_y, pad_x, (raster_data, xs, ys) = _fit_halo_to_chunks(
pad_y, pad_x, raster.data, xs, ys)

# Capture closure vars for the chunk function
tv = target_values
md = max_distance
Expand All @@ -588,7 +629,7 @@ def _chunk_func(data_chunk, xs_chunk, ys_chunk):

return da.map_overlap(
_chunk_func,
raster.data, xs, ys,
raster_data, xs, ys,
depth=(pad_y, pad_x),
boundary=np.nan,
meta=cp.array((), dtype=cp.float32),
Expand Down Expand Up @@ -1390,6 +1431,8 @@ def _process_dask(raster, xs, ys):
else:
pad_y, pad_x = _halo_depth(
x_coords, y_coords, max_distance, distance_metric)
pad_y, pad_x, (raster.data, xs, ys) = _fit_halo_to_chunks(
pad_y, pad_x, raster.data, xs, ys)

out = da.map_overlap(
_process_numpy,
Expand Down
34 changes: 34 additions & 0 deletions xrspatial/tests/test_proximity.py
Original file line number Diff line number Diff line change
Expand Up @@ -1440,6 +1440,40 @@ def test_bounded_dask_single_row_or_col_matches_numpy(func, shape_name):
result.values, expected, equal_nan=True, rtol=1e-5)


# --- issue #2854: bounded-dask halo depth larger than an axis length -------


@pytest.mark.skipif(da is None, reason="dask is not installed")
@pytest.mark.parametrize("func", [proximity, allocation, direction])
def test_bounded_dask_skinny_raster_matches_numpy(func):
    """A halo deeper than a raster axis must not break the bounded dask path.

    Regression test for issue #2854. For a skinny raster ``_halo_depth`` may
    yield a pixel radius that exceeds the raster height or width. That value
    used to be forwarded unchanged to ``da.map_overlap``, which refuses a
    depth larger than the array along that axis and raised
    ``ValueError: The overlapping depth ... is larger than your array ...``.
    A valid raster with a finite ``max_distance`` has to run to completion
    and reproduce the numpy backend's output.
    """
    n_rows, n_cols = 3, 100
    call_kwargs = dict(x='lon', y='lat', max_distance=10)

    # 3 x 100 grid with a single target cell in the middle row.
    cells = np.zeros((n_rows, n_cols), dtype=np.float64)
    cells[1, 50] = 1.0
    lon_values = np.linspace(0, n_cols - 1, n_cols)
    lat_values = np.linspace(0, n_rows - 1, n_rows)

    numpy_raster = xr.DataArray(cells, dims=['lat', 'lon'])
    numpy_raster['lon'] = lon_values
    numpy_raster['lat'] = lat_values
    expected = func(numpy_raster, **call_kwargs).data

    # Same raster, dask-backed as a single chunk covering the whole array.
    lazy_raster = numpy_raster.copy()
    lazy_raster.data = da.from_array(cells, chunks=(n_rows, n_cols))
    result = func(lazy_raster, **call_kwargs)

    assert isinstance(result.data, da.Array)
    np.testing.assert_allclose(
        result.values, expected, equal_nan=True, rtol=1e-5)


@pytest.mark.parametrize("func", [proximity, allocation, direction])
def test_target_values_none_default_matches_empty_list(func):
# target_values default switched from [] to a None sentinel; passing
Expand Down
Loading