Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions xrspatial/hydro/flow_accumulation_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class cupy: # type: ignore[no-redef]
da = None

from xrspatial.utils import (
_validate_mfd_fractions,
_validate_raster,
cuda_args,
has_cuda_and_cupy,
Expand Down Expand Up @@ -871,6 +872,9 @@ def flow_accumulation_mfd(flow_dir_mfd: xr.DataArray,
f"got shape {data.shape}"
)

_validate_mfd_fractions(data, func_name='flow_accumulation_mfd',
name='flow_dir_mfd')

if isinstance(data, np.ndarray):
_check_memory(data.shape[1], data.shape[2])
out = _flow_accum_mfd_cpu(
Expand Down
4 changes: 4 additions & 0 deletions xrspatial/hydro/flow_length_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

from xrspatial.hydro._boundary_store import BoundaryStore
from xrspatial.utils import (
_validate_mfd_fractions,
_validate_raster,
get_dataarray_resolution,
has_cuda_and_cupy,
Expand Down Expand Up @@ -1074,6 +1075,9 @@ def flow_length_mfd(flow_dir_mfd: xr.DataArray,
f"got shape {data.shape}"
)

_validate_mfd_fractions(data, func_name='flow_length_mfd',
name='flow_dir_mfd')

cellsize_x, cellsize_y = get_dataarray_resolution(flow_dir_mfd)
if not (np.isfinite(cellsize_x) and cellsize_x != 0
and np.isfinite(cellsize_y) and cellsize_y != 0):
Expand Down
4 changes: 4 additions & 0 deletions xrspatial/hydro/flow_path_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class cupy: # type: ignore[no-redef]

from xrspatial.utils import (
_validate_matching_shape,
_validate_mfd_fractions,
_validate_raster,
has_cuda_and_cupy,
is_cupy_array,
Expand Down Expand Up @@ -453,6 +454,9 @@ def flow_path_mfd(flow_dir_mfd: xr.DataArray,
raise ValueError(
f"flow_dir_mfd must have shape (8, H, W), got {data.shape}")

_validate_mfd_fractions(data, func_name='flow_path_mfd',
name='flow_dir_mfd')

_, H, W = data.shape

_validate_matching_shape(
Expand Down
4 changes: 4 additions & 0 deletions xrspatial/hydro/hand_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
)
from xrspatial.utils import (
_validate_matching_shape,
_validate_mfd_fractions,
_validate_raster,
has_cuda_and_cupy,
is_cupy_array,
Expand Down Expand Up @@ -709,6 +710,9 @@ def hand_mfd(flow_dir_mfd: xr.DataArray,
raise ValueError(
f"flow_dir_mfd must have shape (8, H, W), got {data.shape}")

_validate_mfd_fractions(data, func_name='hand_mfd',
name='flow_dir_mfd')

_, H, W = data.shape

_validate_matching_shape(
Expand Down
3 changes: 3 additions & 0 deletions xrspatial/hydro/stream_link_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ class cupy: # type: ignore[no-redef]

from xrspatial.utils import (
_validate_matching_shape,
_validate_mfd_fractions,
_validate_raster,
cuda_args,
has_cuda_and_cupy,
Expand Down Expand Up @@ -1029,6 +1030,8 @@ def stream_link_mfd(fractions: xr.DataArray,
_validate_matching_shape(
flow_accum, data.shape[1:], func_name='stream_link_mfd',
name='flow_accum', expected_name='fractions')
_validate_mfd_fractions(data, func_name='stream_link_mfd',
name='fractions')

fa_data = flow_accum.data

Expand Down
3 changes: 3 additions & 0 deletions xrspatial/hydro/stream_order_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class cupy: # type: ignore[no-redef]

from xrspatial.utils import (
_validate_matching_shape,
_validate_mfd_fractions,
_validate_raster,
cuda_args,
has_cuda_and_cupy,
Expand Down Expand Up @@ -1516,6 +1517,8 @@ def stream_order_mfd(fractions: xr.DataArray,
_validate_matching_shape(
flow_accum, frac_data.shape[1:], func_name='stream_order_mfd',
name='flow_accum', expected_name='fractions')
_validate_mfd_fractions(frac_data, func_name='stream_order_mfd',
name='fractions')

if isinstance(frac_data, np.ndarray):
_check_memory(frac_data.shape[1], frac_data.shape[2])
Expand Down
188 changes: 188 additions & 0 deletions xrspatial/hydro/tests/test_validate_mfd_fractions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,188 @@
"""Tests for issue #2873: MFD public APIs validate fraction VALUES.

The public MFD functions document a fraction-grid contract: each cell's
8 bands are in [0, 1] and sum to either 1.0 (flow) or 0.0
(pit/flat/sink), with all-NaN bands at edges and nodata cells. Before
this change they only checked the (8, H, W) shape and ran hydrology math
on whatever values they were given.

These tests pin the value validation: negative fractions, band sums that
are neither ~1.0 nor ~0.0, and partial-NaN band patterns all raise a
clear ValueError on the in-memory (numpy / cupy) backends. Valid grids
from ``flow_direction_mfd`` still pass, and dask inputs skip the eager
value check so laziness is preserved.
"""

import numpy as np
import pytest
import xarray as xr

from xrspatial.hydro.flow_direction_mfd import flow_direction_mfd
from xrspatial.hydro.flow_accumulation_mfd import flow_accumulation_mfd
from xrspatial.hydro.flow_length_mfd import flow_length_mfd
from xrspatial.hydro.stream_order_mfd import stream_order_mfd
from xrspatial.hydro.stream_link_mfd import stream_link_mfd
from xrspatial.hydro.flow_path_mfd import flow_path_mfd
from xrspatial.hydro.hand_mfd import hand_mfd
from xrspatial.hydro.watershed_mfd import watershed_mfd


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------

def _bowl(n=7):
"""A simple bowl whose center is the lowest cell."""
y = np.arange(n, dtype=np.float64) - n // 2
x = np.arange(n, dtype=np.float64) - n // 2
yy, xx = np.meshgrid(y, x, indexing='ij')
return xr.DataArray(yy ** 2 + xx ** 2, dims=['y', 'x'])


def _valid_mfd(n=7):
"""A valid MFD fraction grid produced by flow_direction_mfd."""
return flow_direction_mfd(_bowl(n))


def _flow_accum(mfd):
return flow_accumulation_mfd(mfd)


def _interior_flow_cell(mfd):
"""Return (r, c) of an interior cell whose bands sum to ~1.0.

Used to inject corruption into a cell the validator will inspect.
"""
vals = mfd.values
sums = np.nansum(vals, axis=0)
nan_count = np.isnan(vals).sum(axis=0)
rows, cols = np.where((nan_count == 0) & (np.abs(sums - 1.0) <= 1e-6))
assert len(rows) > 0, "test grid has no normal flow cell"
return int(rows[0]), int(cols[0])


def _corrupt(mfd, mutate):
"""Copy *mfd* and apply *mutate(values)* in place, return DataArray."""
bad = mfd.copy(deep=True)
mutate(bad.values)
return bad


# Each entry: (name, callable taking the (possibly corrupt) fraction grid)
def _callers(mfd):
fa = _flow_accum(_valid_mfd()) # valid accumulation as a secondary arg
sp = xr.DataArray(np.full(mfd.shape[1:], np.nan), dims=mfd.dims[1:])
sp.values[0, 0] = 1.0
pp = sp
elev = _bowl()
return [
("flow_accumulation_mfd", lambda g: flow_accumulation_mfd(g)),
("flow_length_mfd", lambda g: flow_length_mfd(g)),
("stream_order_mfd", lambda g: stream_order_mfd(g, fa, threshold=1)),
("stream_link_mfd", lambda g: stream_link_mfd(g, fa, threshold=1)),
("flow_path_mfd", lambda g: flow_path_mfd(g, sp)),
("hand_mfd", lambda g: hand_mfd(g, fa, elev, threshold=1)),
("watershed_mfd", lambda g: watershed_mfd(g, pp)),
]


# ---------------------------------------------------------------------------
# Valid input still passes
# ---------------------------------------------------------------------------

class TestValidInputPasses:
def test_all_consumers_accept_valid_grid(self):
mfd = _valid_mfd()
for fname, call in _callers(mfd):
# should not raise
call(mfd)


# ---------------------------------------------------------------------------
# Negative fractions
# ---------------------------------------------------------------------------

class TestNegativeFractions:
def test_each_consumer_rejects_negative(self):
mfd = _valid_mfd()
r, c = _interior_flow_cell(mfd)

def mutate(v):
v[0, r, c] = -0.5
v[1, r, c] += 0.5 # keep the band sum at 1.0

bad = _corrupt(mfd, mutate)
for fname, call in _callers(mfd):
with pytest.raises(ValueError, match="negative"):
call(bad)


# ---------------------------------------------------------------------------
# Band sums outside {0, 1}
# ---------------------------------------------------------------------------

class TestBandSums:
def test_each_consumer_rejects_sum_above_one(self):
mfd = _valid_mfd()
r, c = _interior_flow_cell(mfd)
bad = _corrupt(mfd, lambda v: v.__setitem__((slice(None), r, c), 0.5))
# 8 bands * 0.5 = 4.0
for fname, call in _callers(mfd):
with pytest.raises(ValueError, match="sum"):
call(bad)

def test_sink_cell_sum_zero_is_accepted(self):
# The bowl center is a pit: all 8 bands are 0.0 (sum 0.0). This
# must remain valid.
mfd = _valid_mfd()
vals = mfd.values
nan_count = np.isnan(vals).sum(axis=0)
sums = np.nansum(vals, axis=0)
has_sink = np.any((nan_count == 0) & (np.abs(sums) <= 1e-6))
assert has_sink, "test grid has no pit/sink cell"
flow_accumulation_mfd(mfd) # should not raise


# ---------------------------------------------------------------------------
# Partial-NaN band pattern
# ---------------------------------------------------------------------------

class TestPartialNaN:
def test_each_consumer_rejects_partial_nan(self):
mfd = _valid_mfd()
r, c = _interior_flow_cell(mfd)

def mutate(v):
# NaN one band, leave the rest finite -> partial NaN
v[0, r, c] = np.nan

bad = _corrupt(mfd, mutate)
for fname, call in _callers(mfd):
with pytest.raises(ValueError, match="partial-NaN"):
call(bad)

def test_all_nan_cell_is_accepted(self):
# Edge cells from flow_direction_mfd are all-NaN; that is valid.
mfd = _valid_mfd()
assert np.isnan(mfd.values[:, 0, 0]).all()
flow_accumulation_mfd(mfd) # should not raise


# ---------------------------------------------------------------------------
# Dask skips eager value validation (laziness preserved)
# ---------------------------------------------------------------------------

class TestDaskSkipsValueCheck:
def test_dask_invalid_values_not_validated_eagerly(self):
dask = pytest.importorskip('dask.array')
mfd = _valid_mfd()
r, c = _interior_flow_cell(mfd)
bad = _corrupt(mfd, lambda v: v.__setitem__((0, r, c), -5.0))
dask_bad = xr.DataArray(
dask.from_array(bad.values, chunks=(8, 5, 5)),
dims=mfd.dims, coords=mfd.coords,
)
# Should not raise at validation time: dask value checks are
# deferred (laziness preserved), not run eagerly here.
out = flow_accumulation_mfd(dask_bad)
assert isinstance(out.data, dask.Array)
4 changes: 4 additions & 0 deletions xrspatial/hydro/watershed_mfd.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ class cupy: # type: ignore[no-redef]
from xrspatial.hydro._boundary_store import BoundaryStore
from xrspatial.utils import (
_validate_matching_shape,
_validate_mfd_fractions,
_validate_raster,
has_cuda_and_cupy,
is_cupy_array,
Expand Down Expand Up @@ -685,6 +686,9 @@ def watershed_mfd(flow_dir_mfd: xr.DataArray,
raise ValueError(
f"flow_dir_mfd must have shape (8, H, W), got {data.shape}")

_validate_mfd_fractions(data, func_name='watershed_mfd',
name='flow_dir_mfd')

_, H, W = data.shape

_validate_matching_shape(
Expand Down
80 changes: 80 additions & 0 deletions xrspatial/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,86 @@ def _validate_scalar(
)


def _validate_mfd_fractions(data, *, func_name: str, name: str = 'fractions',
atol: float = 1e-6):
"""Validate the VALUES of a (8, H, W) MFD fraction grid.

The public MFD functions document that each cell's 8 fraction
bands lie in ``[0, 1]`` and sum to either 1.0 (flow) or 0.0
(pit/flat/sink), with all-NaN bands at edges and nodata cells.
This checks those value invariants and raises a clear error when
the input violates them, before any hydrology math runs.

Three checks per cell:

* No negative fractions.
* The band sum is 1.0 or 0.0 within *atol*.
* NaN bands are all-or-nothing: either all 8 directions are NaN
(edge/nodata) or none are. A partially-NaN cell is rejected.

Only numpy and cupy (in-memory) arrays are validated. Dask arrays
are skipped so validation does not force computation; lazy
validation is handled separately. The shape is assumed to already
be ``(8, H, W)`` (callers check that first).

Parameters
----------
data : numpy.ndarray, cupy.ndarray, or dask.array.Array
The fraction grid (``DataArray.data``).
func_name : str
Name of the calling function (for error messages).
name : str
Parameter name (for error messages).
atol : float
Absolute tolerance for the band-sum check.

Raises
------
ValueError
If any cell has a negative fraction, a band sum that is
neither ~1.0 nor ~0.0, or a partial-NaN band pattern.
"""
if is_cupy_array(data):
xp = cupy
elif isinstance(data, np.ndarray):
xp = np
else:
# Dask (numpy- or cupy-backed) or other lazy types: skip value
# validation so we do not trigger computation.
return

prefix = f"{func_name}(): `{name}`"

nan_count = xp.isnan(data).sum(axis=0)
# Partial NaN: some but not all of the 8 bands are NaN.
if bool(((nan_count > 0) & (nan_count < 8)).any()):
raise ValueError(
f"{prefix} has cells with a partial-NaN band pattern. Each "
f"cell must have all 8 direction bands NaN (edge/nodata) or "
f"none of them NaN."
)

# NaN < 0 is False, so NaN cells never trip this (no copy needed).
if bool((data < 0).any()):
raise ValueError(
f"{prefix} contains negative flow fractions. Fractions must "
f"be in [0, 1]."
)

# Per-cell band sums, treating NaN bands as 0 so all-NaN cells sum
# to 0.0 and pass the sink check.
sums = xp.nansum(data, axis=0)
valid_cell = nan_count == 0
bad_sum = valid_cell & ~(
(xp.abs(sums - 1.0) <= atol) | (xp.abs(sums) <= atol)
)
if bool(bad_sum.any()):
raise ValueError(
f"{prefix} has cells whose flow fractions do not sum to 1.0 "
f"(flow) or 0.0 (pit/flat/sink) within tolerance {atol}."
)


def _boundary_to_dask(boundary, is_cupy=False):
"""Convert a boundary mode string to the value expected by
``dask.array.map_overlap``'s *boundary* parameter."""
Expand Down
Loading