Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
65 changes: 65 additions & 0 deletions uxarray/core/dataarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from uxarray.core.zonal import (
_compute_conservative_zonal_mean_bands,
_compute_non_conservative_zonal_mean,
_compute_zonal_anomaly,
)
from uxarray.cross_sections import UxDataArrayCrossSectionAccessor
from uxarray.formatting_html import array_repr
Expand Down Expand Up @@ -767,6 +768,70 @@ def zonal_average(self, lat=(-90, 90, 10), conservative: bool = False, **kwargs)
"""Alias of zonal_mean; prefer `zonal_mean` for primary API."""
return self.zonal_mean(lat=lat, conservative=conservative, **kwargs)

def zonal_anomaly(self, lat=(-90, 90, 10), conservative: bool = False):
"""Compute the zonal anomaly: each face value minus the mean of its latitude band.

Returns a new ``UxDataArray`` with the same dimensions as the input,
where each face holds its original value minus the zonal mean of the
latitude band it belongs to.

Parameters
----------
lat : tuple or array-like, default=(-90, 90, 10)
Latitude band specification:
- tuple (start, end, step): band edges via np.linspace(start, end, n)
- array-like: explicit band edges in degrees
conservative : bool, default=False
If True, uses area-weighted band means and blends across bands for
faces that straddle a band boundary, reusing the face-band weight
matrix computed for zonal_mean so no geometry is duplicated.
If False, assigns each face to a band by its centroid latitude.

Returns
-------
UxDataArray
Same dimensions as input with per-face band mean subtracted.

Examples
--------
>>> uxds["var"].zonal_anomaly()
>>> uxds["var"].zonal_anomaly(lat=(-60, 60, 5), conservative=True)
"""
if not self._face_centered():
raise ValueError(
"Zonal anomaly is only supported for face-centered data variables."
)

if isinstance(lat, tuple):
start, end, step = lat
if step <= 0:
raise ValueError("Step size must be positive.")
num_points = int(round((end - start) / step)) + 1
edges = np.linspace(start, end, num_points)
edges = np.clip(edges, -90, 90)
elif isinstance(lat, (list, np.ndarray)):
edges = np.asarray(lat, dtype=float)
else:
raise ValueError(
"Invalid value for 'lat'. Must be a tuple (start, end, step) or array-like band edges."
)

if edges.ndim != 1 or edges.size < 2:
raise ValueError("Band edges must be 1D with at least two values.")

res = _compute_zonal_anomaly(self, edges, conservative=conservative)

return UxDataArray(
res,
dims=self.dims,
coords=self.coords,
name=self.name + "_zonal_anomaly"
if self.name is not None
else "zonal_anomaly",
attrs={"zonal_anomaly": True, "conservative": conservative},
uxgrid=self.uxgrid,
)

def azimuthal_mean(
self,
center_coord,
Expand Down
222 changes: 151 additions & 71 deletions uxarray/core/zonal.py
Original file line number Diff line number Diff line change
Expand Up @@ -225,31 +225,25 @@ def _compute_band_overlap_area(
return area


def _compute_conservative_zonal_mean_bands(uxda, bands):
"""
Compute conservative zonal mean over latitude bands.
def _compute_face_band_weights(uxgrid, bands):
"""Compute overlap area between every face and every latitude band.

Uses get_faces_between_latitudes to optimize computation by avoiding
overlap area calculations for fully contained faces.
Shared geometry kernel used by both zonal_mean and zonal_anomaly so the
expensive intersection calculations are never duplicated.

Parameters
----------
uxda : UxDataArray
The data array to compute zonal means for
uxgrid : Grid
bands : array-like
Latitude band edges in degrees
Latitude band edges in degrees, shape (n_bands + 1,)

Returns
-------
result : array
Zonal means for each band
W : ndarray, shape (n_face, n_bands)
W[f, b] is the overlap area between face f and band b.
Fully-contained faces carry their full face area; partially-overlapping
faces carry the exact intersection area.
"""
import dask.array as da

uxgrid = uxda.uxgrid
face_axis = uxda.get_axis_num("n_face")

# Pre-compute face properties
faces_edge_nodes_xyz = _get_cartesian_face_edge_nodes_array(
uxgrid.face_node_connectivity.values,
uxgrid.n_face,
Expand All @@ -263,80 +257,166 @@ def _compute_conservative_zonal_mean_bands(uxda, bands):
face_areas = uxgrid.face_areas.values

bands = np.asarray(bands, dtype=float)
if bands.ndim != 1 or bands.size < 2:
raise ValueError("bands must be 1D with at least two edges")

nb = bands.size - 1

# Initialize result array
shape = list(uxda.shape)
shape[face_axis] = nb
if isinstance(uxda.data, da.Array):
result = da.zeros(shape, dtype=uxda.dtype)
else:
result = np.zeros(shape, dtype=uxda.dtype)
W = np.zeros((uxgrid.n_face, nb), dtype=float)

for bi in range(nb):
lat0 = float(np.clip(bands[bi], -90.0, 90.0))
lat1 = float(np.clip(bands[bi + 1], -90.0, 90.0))

# Ensure lat0 <= lat1
if lat0 > lat1:
lat0, lat1 = lat1, lat0

z0 = np.sin(np.deg2rad(lat0))
z1 = np.sin(np.deg2rad(lat1))
zmin, zmax = (z0, z1) if z0 <= z1 else (z1, z0)

# Step 1: Get fully contained faces
fully_contained_faces = uxgrid.get_faces_between_latitudes((lat0, lat1))

# Step 2: Get all overlapping faces (including partial)
fully_contained = uxgrid.get_faces_between_latitudes((lat0, lat1))
mask = ~((face_bounds_lat[:, 1] < lat0) | (face_bounds_lat[:, 0] > lat1))
all_overlapping_faces = np.nonzero(mask)[0]
all_overlapping = np.nonzero(mask)[0]

if all_overlapping_faces.size == 0:
# No faces in this band
idx = [slice(None)] * result.ndim
idx[face_axis] = bi
result[tuple(idx)] = np.nan
if all_overlapping.size == 0:
continue

# Step 3: Partition faces into fully contained vs partially overlapping
is_fully_contained = np.isin(all_overlapping_faces, fully_contained_faces)
partially_overlapping_faces = all_overlapping_faces[~is_fully_contained]

# Step 4: Compute weights
all_weights = np.zeros(all_overlapping_faces.size, dtype=float)

# For fully contained faces, use their full area
if fully_contained_faces.size > 0:
fully_contained_indices = np.where(is_fully_contained)[0]
all_weights[fully_contained_indices] = face_areas[fully_contained_faces]

# For partially overlapping faces, compute fractional area
if partially_overlapping_faces.size > 0:
partial_indices = np.where(~is_fully_contained)[0]
for i, face_idx in enumerate(partially_overlapping_faces):
nedge = n_nodes_per_face[face_idx]
face_edges = faces_edge_nodes_xyz[face_idx, :nedge]
overlap_area = _compute_band_overlap_area(face_edges, zmin, zmax)
all_weights[partial_indices[i]] = overlap_area

# Step 5: Compute weighted average
data_slice = uxda.isel(n_face=all_overlapping_faces, ignore_grid=True).data
total_weight = all_weights.sum()

if total_weight == 0.0:
weighted = np.nan * data_slice[..., 0]
else:
w_shape = [1] * data_slice.ndim
w_shape[face_axis] = all_weights.size
w_reshaped = all_weights.reshape(w_shape)
weighted = (data_slice * w_reshaped).sum(axis=face_axis) / total_weight
is_fully_contained = np.isin(all_overlapping, fully_contained)

fc = all_overlapping[is_fully_contained]
W[fc, bi] = face_areas[fc]

for f in all_overlapping[~is_fully_contained]:
nedge = n_nodes_per_face[f]
W[f, bi] = _compute_band_overlap_area(
faces_edge_nodes_xyz[f, :nedge], zmin, zmax
)

return W


def _compute_conservative_zonal_mean_bands(uxda, bands):
"""Compute conservative zonal mean over latitude bands.

Parameters
----------
uxda : UxDataArray
bands : array-like
Latitude band edges in degrees

Returns
-------
result : array
Zonal means for each band, with n_face axis replaced by n_bands
"""
import dask.array as da

bands = np.asarray(bands, dtype=float)
if bands.ndim != 1 or bands.size < 2:
raise ValueError("bands must be 1D with at least two edges")

W = _compute_face_band_weights(uxda.uxgrid, bands) # (n_face, n_bands)
nb = W.shape[1]
face_axis = uxda.get_axis_num("n_face")

shape = list(uxda.shape)
shape[face_axis] = nb
if isinstance(uxda.data, da.Array):
result = da.full(shape, np.nan, dtype=float)
else:
result = np.full(shape, np.nan, dtype=float)

for bi in range(nb):
overlapping = np.nonzero(W[:, bi] > 0)[0]
if overlapping.size == 0:
continue

w = W[overlapping, bi]
total = w.sum()
if total == 0.0:
continue

data_slice = uxda.isel(n_face=overlapping, ignore_grid=True).data
w_shape = [1] * data_slice.ndim
w_shape[face_axis] = w.size
weighted = (data_slice * w.reshape(w_shape)).sum(axis=face_axis) / total

idx = [slice(None)] * result.ndim
idx[face_axis] = bi
result[tuple(idx)] = weighted

return result


def _compute_zonal_anomaly(uxda, bands, conservative=False):
"""Compute zonal anomaly: each face value minus the mean of its latitude band.

Parameters
----------
uxda : UxDataArray
bands : array-like
Latitude band edges in degrees
conservative : bool
If True, uses area-weighted band means and blends across bands for
faces that straddle a boundary, reusing the same weight matrix as
zonal_mean so geometry is computed only once.
If False, assigns each face to a band by centroid latitude.

Returns
-------
ndarray
Same shape as uxda, with the per-face band mean subtracted.
"""
bands = np.asarray(bands, dtype=float)
face_axis = uxda.get_axis_num("n_face")
n_face = uxda.uxgrid.n_face
nb = bands.size - 1

if conservative:
# Single geometry pass shared with zonal_mean
W = _compute_face_band_weights(uxda.uxgrid, bands) # (n_face, n_bands)

# Band means
band_means = np.full(nb, np.nan)
for bi in range(nb):
overlapping = np.nonzero(W[:, bi] > 0)[0]
if overlapping.size == 0:
continue
w = W[overlapping, bi]
total = w.sum()
if total > 0:
vals = uxda.isel(n_face=overlapping, ignore_grid=True).values
band_means[bi] = (w * vals).sum() / total

# Map band means back to faces; straddling faces get area-weighted blend
face_totals = W.sum(axis=1)
valid = face_totals > 0
face_means = np.where(
valid,
np.where(
valid,
(
W * np.where(np.isnan(band_means), 0.0, band_means)[np.newaxis, :]
).sum(axis=1)
/ np.where(valid, face_totals, 1.0),
np.nan,
),
np.nan,
)
else:
# Centroid-based: fast, no intersection geometry needed
face_lats = uxda.uxgrid.face_lat.values
band_indices = np.clip(np.digitize(face_lats, bands) - 1, 0, nb - 1)

band_means = np.full(nb, np.nan)
for bi in range(nb):
mask = band_indices == bi
if mask.any():
band_means[bi] = float(
uxda.isel(
n_face=np.nonzero(mask)[0], ignore_grid=True
).values.mean()
)

face_means = band_means[band_indices]

# Broadcast face_means to match uxda shape (face axis may not be last)
shape = [1] * uxda.ndim
shape[face_axis] = n_face
return uxda.values - face_means.reshape(shape)
Loading