Skip to content

Commit 7eb2e61

Browse files
Merge pull request #369 from scverse/fix/assert-xenium-shapes-id
Xenium avoid calling `get_element_instances()`; fix deprecation warning
2 parents 28eacc9 + b7f2318 commit 7eb2e61

6 files changed

Lines changed: 64 additions & 166 deletions

File tree

docs/conf.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@
127127
# If building the documentation fails because of a missing link that is outside your control,
128128
# you can add an exception to this list.
129129
("py:class", "Path"),
130+
("py:class", "pathlib._local.Path"),
130131
("py:class", "AnnData"),
131132
("py:class", "SpatialData"),
132133
("py:func", "imageio.imread"), # maybe this can be fixed

pyproject.toml

Lines changed: 22 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,18 @@ requires = ["hatchling", "hatch-vcs"]
55

66
[project]
77
name = "spatialdata-io"
8-
dynamic= [
9-
"version" # allow version to be set by git tags
8+
dynamic = [
9+
"version" # allow version to be set by git tags
1010
]
1111
description = "SpatialData IO for common techs"
1212
readme = "README.md"
1313
requires-python = ">=3.11"
14-
license = {file = "LICENSE"}
14+
license = { file = "LICENSE" }
1515
authors = [
16-
{name = "scverse"},
16+
{ name = "scverse" },
1717
]
1818
maintainers = [
19-
{name = "scverse", email = "scverse@scverse.scverse"},
19+
{ name = "scverse", email = "scverse@scverse.scverse" },
2020
]
2121
urls.Documentation = "https://spatialdata-io.readthedocs.io/"
2222
urls.Source = "https://github.com/scverse/spatialdata-io"
@@ -26,7 +26,7 @@ dependencies = [
2626
"click",
2727
"numpy",
2828
"scanpy",
29-
"spatialdata>=0.2.6",
29+
"spatialdata>=0.7.3a0",
3030
"scikit-image",
3131
"h5py",
3232
"joblib",
@@ -46,7 +46,7 @@ dev = [
4646
"pre-commit"
4747
]
4848
doc = [
49-
"sphinx>=4.5",
49+
"sphinx>=4.5,<9",
5050
"sphinx-book-theme>=1.0.0",
5151
"myst-nb",
5252
"sphinxcontrib-bibtex>=1.0.0",
@@ -67,7 +67,7 @@ test = [
6767
# update: readthedocs doesn't seem to try to install pre-releases even when trying to install the pre optional-dependency. For
6868
# the moment, if needed, let's add the latest pre-release explicitly here.
6969
pre = [
70-
"spatialdata>=0.4.0rc0"
70+
"spatialdata>=0.7.3a0"
7171
]
7272

7373
[tool.coverage.run]
@@ -80,7 +80,7 @@ omit = [
8080
testpaths = ["tests"]
8181
xfail_strict = true
8282
addopts = [
83-
"--import-mode=importlib", # allow using test files with same name
83+
"--import-mode=importlib", # allow using test files with same name
8484
]
8585

8686
[tool.ruff]
@@ -95,19 +95,19 @@ exclude = [
9595
"setup.py",
9696
]
9797
lint.select = [
98-
"F", # Errors detected by Pyflakes
99-
"E", # Error detected by Pycodestyle
100-
"W", # Warning detected by Pycodestyle
101-
"I", # isort
102-
"D", # pydocstyle
103-
"B", # flake8-bugbear
104-
"TID", # flake8-tidy-imports
105-
"C4", # flake8-comprehensions
106-
"BLE", # flake8-blind-except
107-
"UP", # pyupgrade
108-
"RUF100", # Report unused noqa directives
109-
"TCH", # Typing imports
110-
"NPY", # Numpy specific rules
98+
"F", # Errors detected by Pyflakes
99+
"E", # Error detected by Pycodestyle
100+
"W", # Warning detected by Pycodestyle
101+
"I", # isort
102+
"D", # pydocstyle
103+
"B", # flake8-bugbear
104+
"TID", # flake8-tidy-imports
105+
"C4", # flake8-comprehensions
106+
"BLE", # flake8-blind-except
107+
"UP", # pyupgrade
108+
"RUF100", # Report unused noqa directives
109+
"TCH", # Typing imports
110+
"NPY", # Numpy specific rules
111111
# "PTH", # Use pathlib
112112
# "S" # Security
113113
]

src/spatialdata_io/readers/_utils/_image.py

Lines changed: 1 addition & 58 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77
from numpy.typing import NDArray
88
from spatialdata.models.models import Chunks_t
99

10-
__all__ = ["Chunks_t", "_compute_chunks", "_read_chunks", "normalize_chunks"]
10+
__all__ = ["Chunks_t", "_compute_chunks", "_read_chunks"]
1111

1212
_Y_IDX = 0
1313
"""Index of y coordinate in chunk coordinate array format: (y, x, height, width)"""
@@ -143,60 +143,3 @@ def _read_chunks(
143143
for chunk_y in range(coords.shape[0])
144144
]
145145
return chunks
146-
147-
148-
def normalize_chunks(
149-
chunks: Chunks_t | None,
150-
axes: Sequence[str],
151-
) -> dict[str, int]:
152-
"""Normalize chunk specification to dict format.
153-
154-
This function converts various chunk formats to a dict mapping dimension names
155-
to chunk sizes. The dict format is preferred because it's explicit about which
156-
dimension gets which chunk size and is compatible with spatialdata.
157-
158-
Parameters
159-
----------
160-
chunks
161-
Chunk specification. Can be:
162-
- None: Uses DEFAULT_CHUNK_SIZE for all axes
163-
- int: Applied to all axes
164-
- tuple[int, ...]: Chunk sizes in order corresponding to axes
165-
- dict: Mapping of axis names to chunk sizes (validated against axes)
166-
axes
167-
Tuple of axis names that defines the expected dimensions (e.g., ('c', 'y', 'x')).
168-
169-
Returns
170-
-------
171-
dict[str, int]
172-
Dict mapping axis names to chunk sizes.
173-
174-
Raises
175-
------
176-
ValueError
177-
If chunks format is not supported or incompatible with axes.
178-
"""
179-
if chunks is None:
180-
return dict.fromkeys(axes, DEFAULT_CHUNK_SIZE)
181-
182-
if isinstance(chunks, int):
183-
return dict.fromkeys(axes, chunks)
184-
185-
if isinstance(chunks, Mapping):
186-
chunks_dict = dict(chunks)
187-
missing = set(axes) - set(chunks_dict.keys())
188-
if missing:
189-
raise ValueError(f"chunks dict missing keys for axes {missing}, got: {list(chunks_dict.keys())}")
190-
return {ax: chunks_dict[ax] for ax in axes}
191-
192-
if isinstance(chunks, tuple):
193-
if len(chunks) != len(axes):
194-
raise ValueError(f"chunks tuple length {len(chunks)} doesn't match axes {axes} (length {len(axes)})")
195-
if not all(isinstance(c, int) for c in chunks):
196-
raise ValueError(f"All elements in chunks tuple must be int, got: {chunks}")
197-
return dict(zip(axes, chunks, strict=True))
198-
199-
raise ValueError(f"Unsupported chunks type: {type(chunks)}. Expected int, tuple, dict, or None.")
200-
201-
202-
##

src/spatialdata_io/readers/generic.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import numpy as np
88
import tifffile
99
from dask_image.imread import imread
10-
from geopandas import GeoDataFrame
1110
from spatialdata._docs import docstring_parameter
1211
from spatialdata._logging import logger
1312
from spatialdata.models import Image2DModel, ShapesModel
@@ -23,10 +22,12 @@
2322
from xarray import DataArray
2423

2524

25+
from spatialdata.models.chunks_utils import normalize_chunks
26+
2627
from spatialdata_io.readers._utils._image import (
28+
DEFAULT_CHUNK_SIZE,
2729
_compute_chunks,
2830
_read_chunks,
29-
normalize_chunks,
3031
)
3132

3233
VALID_IMAGE_TYPES = [".tif", ".tiff", ".png", ".jpg", ".jpeg"]
@@ -179,7 +180,7 @@ def image(
179180
chunks: Chunks_t | None = None,
180181
scale_factors: Sequence[int] | None = None,
181182
) -> DataArray:
182-
"""Reads an image file and returns a parsed Image2D spatial element.
183+
"""Read an image file and return a parsed Image2D spatial element.
183184
184185
Parameters
185186
----------
@@ -207,6 +208,8 @@ def image(
207208
# Map passed data axes to position of dimension
208209
axes_dim_mapping = {axes: ndim for ndim, axes in enumerate(data_axes)}
209210

211+
if chunks is None:
212+
chunks = DEFAULT_CHUNK_SIZE
210213
chunks_dict = normalize_chunks(chunks, axes=data_axes)
211214

212215
im = None

src/spatialdata_io/readers/xenium.py

Lines changed: 34 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,11 @@
2020
from dask.dataframe import read_parquet
2121
from dask_image.imread import imread
2222
from geopandas import GeoDataFrame
23-
from joblib import Parallel, delayed
2423
from pyarrow import Table
2524
from shapely import GeometryType, Polygon, from_ragged_array
2625
from spatialdata import SpatialData
2726
from spatialdata._core.query.relational_query import get_element_instances
27+
from spatialdata._logging import logger
2828
from spatialdata.models import (
2929
Image2DModel,
3030
Labels2DModel,
@@ -61,7 +61,7 @@ def xenium(
6161
*,
6262
cells_boundaries: bool = True,
6363
nucleus_boundaries: bool = True,
64-
cells_as_circles: bool | None = None,
64+
cells_as_circles: bool = False,
6565
cells_labels: bool = True,
6666
nucleus_labels: bool = True,
6767
transcripts: bool = True,
@@ -136,7 +136,7 @@ def xenium(
136136
137137
Notes
138138
-----
139-
Old versions. Until spatialdata-io v0.1.3post0: previously, `cells_as_circles` was `True` by default; the table was associated to the
139+
Old versions. Until spatialdata-io v0.6.0: `cells_as_circles` was `True` by default; the table was associated to the
140140
circles when `cells_as_circles` was `True`, and the table was associated to the polygons when `cells_as_circles`
141141
was `False`; the radii of the circles were computed from the nuclei instead of the cells.
142142
@@ -153,14 +153,6 @@ def xenium(
153153
... )
154154
>>> sdata.write("path/to/data.zarr")
155155
"""
156-
if cells_as_circles is None:
157-
cells_as_circles = True
158-
warnings.warn(
159-
"The default value of `cells_as_circles` will change to `False` in the next release. "
160-
"Please pass `True` explicitly to maintain the current behavior.",
161-
DeprecationWarning,
162-
stacklevel=3,
163-
)
164156
image_models_kwargs, labels_models_kwargs = _initialize_raster_models_kwargs(
165157
image_models_kwargs, labels_models_kwargs
166158
)
@@ -223,18 +215,16 @@ def xenium(
223215
# labels.
224216
if nucleus_labels:
225217
labels["nucleus_labels"], _ = _get_labels_and_indices_mapping(
226-
path,
227-
XeniumKeys.CELLS_ZARR,
228-
specs,
218+
path=path,
219+
specs=specs,
229220
mask_index=0,
230221
labels_name="nucleus_labels",
231222
labels_models_kwargs=labels_models_kwargs,
232223
)
233224
if cells_labels:
234225
labels["cell_labels"], cell_labels_indices_mapping = _get_labels_and_indices_mapping(
235-
path,
236-
XeniumKeys.CELLS_ZARR,
237-
specs,
226+
path=path,
227+
specs=specs,
238228
mask_index=1,
239229
labels_name="cell_labels",
240230
labels_models_kwargs=labels_models_kwargs,
@@ -360,8 +350,8 @@ def filter(self, record: logging.LogRecord) -> bool:
360350
return False
361351
return True
362352

363-
logger = tifffile.logger()
364-
logger.addFilter(IgnoreSpecificMessage())
353+
tf_logger = tifffile.logger()
354+
tf_logger.addFilter(IgnoreSpecificMessage())
365355
image_models_kwargs = dict(image_models_kwargs)
366356
assert "c_coords" not in image_models_kwargs, (
367357
"The channel names for the morphology focus images are handled internally"
@@ -374,7 +364,7 @@ def filter(self, record: logging.LogRecord) -> bool:
374364
image_models_kwargs,
375365
)
376366
del image_models_kwargs["c_coords"]
377-
logger.removeFilter(IgnoreSpecificMessage())
367+
tf_logger.removeFilter(IgnoreSpecificMessage())
378368

379369
if table is not None:
380370
tables["table"] = table
@@ -402,14 +392,16 @@ def filter(self, record: logging.LogRecord) -> bool:
402392
def _decode_cell_id_column(cell_id_column: pd.Series) -> pd.Series:
403393
if isinstance(cell_id_column.iloc[0], bytes):
404394
return cell_id_column.str.decode("utf-8")
395+
if not isinstance(cell_id_column.iloc[0], str):
396+
cell_id_column.index = cell_id_column.index.astype(str)
405397
return cell_id_column
406398

407399

408400
def _get_polygons(
409401
path: Path,
410402
file: str,
411403
specs: dict[str, Any],
412-
idx: ArrayLike | None = None,
404+
idx: pd.Series | None = None,
413405
) -> GeoDataFrame:
414406
# seems to be faster than pd.read_parquet
415407
df = pq.read_table(path / file).to_pandas()
@@ -448,7 +440,7 @@ def _get_polygons(
448440
if version is not None and version < packaging.version.parse("2.0.0"):
449441
assert idx is not None
450442
assert len(idx) == len(geo_df)
451-
assert index.equals(idx)
443+
assert np.array_equal(index.values, idx.values)
452444
else:
453445
if np.unique(geo_df.index).size != len(geo_df):
454446
warnings.warn(
@@ -464,7 +456,6 @@ def _get_polygons(
464456

465457
def _get_labels_and_indices_mapping(
466458
path: Path,
467-
file: str,
468459
specs: dict[str, Any],
469460
mask_index: int,
470461
labels_name: str,
@@ -493,36 +484,35 @@ def _get_labels_and_indices_mapping(
493484
cell_id, dataset_suffix = z["cell_id"][...].T
494485
cell_id_str = cell_id_str_from_prefix_suffix_uint32(cell_id, dataset_suffix)
495486

496-
# this information will probably be available in the `label_id` column for version > 2.0.0 (see public
497-
# release notes mentioned above)
498-
real_label_index = get_element_instances(labels).values
499-
500-
# background removal
501-
if real_label_index[0] == 0:
502-
real_label_index = real_label_index[1:]
503-
504487
if version < packaging.version.parse("2.0.0"):
505-
expected_label_index = z["seg_mask_value"][...]
506-
507-
if not np.array_equal(expected_label_index, real_label_index):
508-
raise ValueError(
509-
"The label indices from the labels differ from the ones from the input data. Please report "
510-
f"this issue. Real label indices: {real_label_index}, expected label indices: "
511-
f"{expected_label_index}."
512-
)
488+
label_index = z["seg_mask_value"][...]
513489
else:
514-
labels_positional_indices = z["polygon_sets"][f"{mask_index}"]["cell_index"][...]
515-
if not np.array_equal(labels_positional_indices, np.arange(len(labels_positional_indices))):
516-
raise ValueError(
517-
"The positional indices of the labels do not match the expected range. Please report this issue."
490+
# For v >= 2.0.0, seg_mask_value is no longer available in the zarr;
491+
# read label_id from the corresponding parquet boundary file instead
492+
boundaries_file = XeniumKeys.NUCLEUS_BOUNDARIES_FILE if mask_index == 0 else XeniumKeys.CELL_BOUNDARIES_FILE
493+
boundary_columns = pq.read_schema(path / boundaries_file).names
494+
if "label_id" in boundary_columns:
495+
boundary_df = pq.read_table(path / boundaries_file, columns=[XeniumKeys.CELL_ID, "label_id"]).to_pandas()
496+
unique_pairs = boundary_df.drop_duplicates(subset=[XeniumKeys.CELL_ID, "label_id"]).copy()
497+
unique_pairs[XeniumKeys.CELL_ID] = _decode_cell_id_column(unique_pairs[XeniumKeys.CELL_ID])
498+
cell_id_to_label_id = unique_pairs.set_index(XeniumKeys.CELL_ID)["label_id"]
499+
label_index = cell_id_to_label_id.loc[cell_id_str].values
500+
else:
501+
# fallback for dev versions around 2.0.0 that lack both seg_mask_value and label_id
502+
logger.warn(
503+
f"Could not find the labels ids from the metadata for version {version}. Using a fallback (slower) implementation."
518504
)
505+
label_index = get_element_instances(labels).values
506+
507+
if label_index[0] == 0:
508+
label_index = label_index[1:]
519509

520510
# labels_index is an uint32, so let's cast to np.int64 to avoid the risk of overflow on some systems
521511
indices_mapping = pd.DataFrame(
522512
{
523513
"region": labels_name,
524514
"cell_id": cell_id_str,
525-
"label_index": real_label_index.astype(np.int64),
515+
"label_index": label_index.astype(np.int64),
526516
}
527517
)
528518
# because AnnData converts the indices to str

0 commit comments

Comments
 (0)