diff --git a/src/spikeinterface/core/sortinganalyzer.py b/src/spikeinterface/core/sortinganalyzer.py index b5885598fe..1588de85bb 100644 --- a/src/spikeinterface/core/sortinganalyzer.py +++ b/src/spikeinterface/core/sortinganalyzer.py @@ -217,7 +217,9 @@ def create_sorting_analyzer( return sorting_analyzer -def load_sorting_analyzer(folder, load_extensions=True, format="auto", backend_options=None) -> "SortingAnalyzer": +def load_sorting_analyzer( + folder, load_extensions=True, format="auto", backend_options=None, lazy=False +) -> "SortingAnalyzer": """ Load a SortingAnalyzer object from disk. @@ -245,7 +247,9 @@ def load_sorting_analyzer(folder, load_extensions=True, format="auto", backend_o The loaded SortingAnalyzer """ - return SortingAnalyzer.load(folder, load_extensions=load_extensions, format=format, backend_options=backend_options) + return SortingAnalyzer.load( + folder, load_extensions=load_extensions, format=format, backend_options=backend_options, lazy=lazy + ) class SortingAnalyzer: @@ -279,6 +283,7 @@ def __init__( sparsity: ChannelSparsity | None = None, return_in_uV: bool = True, backend_options: dict | None = None, + lazy: bool = False, ): # very fast init because checks are done in load and create self.sorting = sorting @@ -304,6 +309,9 @@ def __init__( # (additional saving options for creating and saving datasets, e.g. compression/filters for zarr) self._backend_options = {} if backend_options is None else backend_options + # the lazy flag is used to load the extensions in a lazy way (only when needed) + self._lazy = lazy + # extensions are not loaded at init self.extensions = dict() @@ -407,7 +415,7 @@ def create( return sorting_analyzer @classmethod - def load(cls, folder, recording=None, load_extensions=True, format="auto", backend_options=None): + def load(cls, folder, recording=None, load_extensions=True, format="auto", backend_options=None, lazy=False): """ Load folder or zarr. The recording can be given if the recording location has changed. @@ -422,14 +430,14 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto", backe if format == "binary_folder": sorting_analyzer = SortingAnalyzer.load_from_binary_folder( - folder, recording=recording, backend_options=backend_options + folder, recording=recording, backend_options=backend_options, lazy=lazy ) elif format == "zarr": sorting_analyzer = SortingAnalyzer.load_from_zarr( - folder, recording=recording, backend_options=backend_options + folder, recording=recording, backend_options=backend_options, lazy=lazy ) - if not is_path_remote(str(folder)): + if not is_path_remote(str(folder)) and not lazy: if load_extensions: sorting_analyzer.load_all_saved_extension() @@ -532,15 +540,24 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV return cls.load_from_binary_folder(folder, recording=recording, backend_options=backend_options) @classmethod - def load_from_binary_folder(cls, folder, recording=None, backend_options=None): + def load_from_binary_folder(cls, folder, recording=None, backend_options=None, lazy=False): from .loading import load folder = Path(folder) assert folder.is_dir(), f"This folder does not exists {folder}" # load internal sorting copy in memory + if lazy: + numpy_folder_kwargs = dict(mmap_mode="r") + copy_spike_vector = False + else: + numpy_folder_kwargs = dict() + copy_spike_vector = True + sorting = NumpySorting.from_sorting( - NumpyFolderSorting(folder / "sorting"), with_metadata=True, copy_spike_vector=True + NumpyFolderSorting(folder / "sorting", **numpy_folder_kwargs), + with_metadata=True, + copy_spike_vector=copy_spike_vector, ) # Try to load the recording if not provided @@ -601,6 +618,7 @@ def load_from_binary_folder(cls, folder, recording=None, backend_options=None): sparsity=sparsity, return_in_uV=return_in_uV, backend_options=backend_options, + lazy=lazy, ) sorting_analyzer.folder = folder @@ -698,7 +716,7 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att return cls.load_from_zarr(folder, recording=recording, backend_options=backend_options) @classmethod - def load_from_zarr(cls, folder, recording=None, backend_options=None): + def load_from_zarr(cls, folder, recording=None, backend_options=None, lazy=False): import zarr from .loading import load @@ -721,11 +739,22 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): "Please consider re-generating the SortingAnalyzer object." ) - # load internal sorting in memory + if lazy: + copy_spike_vector = False + lazy_spike_vector = True + else: + copy_spike_vector = True + lazy_spike_vector = False + sorting = NumpySorting.from_sorting( - ZarrSortingExtractor(folder, zarr_group="sorting", storage_options=storage_options), + ZarrSortingExtractor( + folder, + zarr_group="sorting", + storage_options=storage_options, + lazy_spike_vector=lazy_spike_vector, + ), with_metadata=True, - copy_spike_vector=True, + copy_spike_vector=copy_spike_vector, ) # load recording if possible @@ -770,6 +799,7 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None): sparsity=sparsity, return_in_uV=return_in_uV, backend_options=backend_options, + lazy=lazy, ) sorting_analyzer.folder = folder @@ -988,6 +1018,11 @@ def _save_or_select_or_merge_or_split( new_sorting_analyzer : SortingAnalyzer The newly created SortingAnalyzer object. """ + if self._lazy: + raise ValueError( + "Cannot save, select, merge or split units when the SortingAnalyzer is lazy. " + "Please load the SortingAnalyzer with lazy=False." + ) if self.has_recording(): recording = self._recording elif self.has_temporary_recording(): @@ -1936,7 +1971,7 @@ def load_extension(self, extension_name: str): if extension_class is None: return None - extension_instance = extension_class.load(self) + extension_instance = extension_class.load(self, lazy=self._lazy) self.extensions[extension_name] = extension_instance @@ -2414,20 +2449,20 @@ def _get_zarr_extension_group(self, mode="r+"): return extension_group @classmethod - def load(cls, sorting_analyzer): + def load(cls, sorting_analyzer, lazy=False): ext = cls(sorting_analyzer) ext.load_params() ext.load_run_info() if ext.run_info is not None: if ext.run_info["run_completed"]: - ext.load_data() + ext.load_data(lazy=lazy) if cls.need_backward_compatibility_on_load: ext._handle_backward_compatibility_on_load() if len(ext.data) > 0: return ext else: # this is for back-compatibility of old analyzers - ext.load_data() + ext.load_data(lazy=lazy) if cls.need_backward_compatibility_on_load: ext._handle_backward_compatibility_on_load() if len(ext.data) > 0: @@ -2527,7 +2562,7 @@ def load_params(self): self.params = params - def load_data(self): + def load_data(self, lazy=False): ext_data = None if self.format == "binary_folder": extension_folder = self._get_binary_extension_folder() @@ -2547,10 +2582,10 @@ def load_data(self): ext_data = json.load(f) elif ext_data_file.suffix == ".npy": # The lazy loading of an extension is complicated because if we compute again - # and have a link to the old buffer on windows then it fails - # ext_data = np.load(ext_data_file, mmap_mode="r") - # so we go back to full loading - ext_data = np.load(ext_data_file) + # and have a link to the old buffer on windows then it fails. + # So, by default, we use full loading, but lazy can be requested on demand. + kwargs = dict(mmap_mode="r") if lazy else dict() + ext_data = np.load(ext_data_file, **kwargs) elif ext_data_file.suffix == ".csv": import pandas as pd @@ -2586,8 +2621,7 @@ def load_data(self): elif "object" in ext_data_.attrs: ext_data = ext_data_[0] else: - # this load in memmory - ext_data = np.array(ext_data_) + ext_data = ext_data_ if lazy else np.array(ext_data_[:]) self.set_data(ext_data_name, ext_data) if len(self.data) == 0: diff --git a/src/spikeinterface/core/sortingfolder.py b/src/spikeinterface/core/sortingfolder.py index c0d66393d2..2dba9d4465 100644 --- a/src/spikeinterface/core/sortingfolder.py +++ b/src/spikeinterface/core/sortingfolder.py @@ -24,7 +24,7 @@ class NumpyFolderSorting(BaseSorting): mode = "folder" name = "NumpyFolder" - def __init__(self, folder_path): + def __init__(self, folder_path, mmap_mode=None): folder_path = Path(folder_path) with open(folder_path / "numpysorting_info.json", "r") as f: @@ -36,7 +36,7 @@ def __init__(self, folder_path): BaseSorting.__init__(self, sampling_frequency, unit_ids) - self.spikes = np.load(folder_path / "spikes.npy") + self.spikes = np.load(folder_path / "spikes.npy", mmap_mode=mmap_mode) for segment_index in range(num_segments): self.add_sorting_segment(SpikeVectorSortingSegment(self.spikes, segment_index, unit_ids)) @@ -47,7 +47,7 @@ def __init__(self, folder_path): folder_metadata = folder_path self.load_metadata_from_folder(folder_metadata) - self._kwargs = dict(folder_path=str(folder_path.absolute())) + self._kwargs = dict(folder_path=str(folder_path.absolute()), mmap_mode=mmap_mode) @staticmethod def write_sorting(sorting, save_path): diff --git a/src/spikeinterface/core/tests/test_sortinganalyzer.py b/src/spikeinterface/core/tests/test_sortinganalyzer.py index a9bd71b5c0..e0411bc9cd 100644 --- a/src/spikeinterface/core/tests/test_sortinganalyzer.py +++ b/src/spikeinterface/core/tests/test_sortinganalyzer.py @@ -119,7 +119,7 @@ def test_SortingAnalyzer_binary_folder(tmp_path, dataset): assert "number" in sorting_analyzer.sorting.get_property_keys() sorting_analyzer_reloded = load_sorting_analyzer(folder, format="auto") assert "quality" in sorting_analyzer_reloded.sorting.get_property_keys() - assert "number" in sorting_analyzer.sorting.get_property_keys() + assert "number" in sorting_analyzer_reloded.sorting.get_property_keys() def test_SortingAnalyzer_zarr(tmp_path, dataset): @@ -201,7 +201,7 @@ def test_SortingAnalyzer_zarr(tmp_path, dataset): assert "number" in sorting_analyzer.sorting.get_property_keys() sorting_analyzer_reloded = load_sorting_analyzer(sorting_analyzer.folder, format="auto") assert "quality" in sorting_analyzer_reloded.sorting.get_property_keys() - assert "number" in sorting_analyzer.sorting.get_property_keys() + assert "number" in sorting_analyzer_reloded.sorting.get_property_keys() def test_create_by_dict(): @@ -325,6 +325,67 @@ def test_SortingAnalyzer_interleaved_probegroup(dataset): assert np.array_equal(recording.get_channel_locations(), sorting_analyzer.get_channel_locations()) +def test_load_in_lazy_mode_binary(tmp_path, dataset): + recording, sorting = dataset + + folder = tmp_path / "test_SortingAnalyzer_binary_folder" + if folder.exists(): + shutil.rmtree(folder) + + sorting_analyzer = create_sorting_analyzer( + sorting, recording, format="binary_folder", folder=folder, sparse=False, sparsity=None + ) + + sorting_analyzer.compute(["random_spikes", "templates", "spike_amplitudes"]) + # load in lazy mode and check that spike vector and extension data are memmap + sorting_analyzer_lazy = load_sorting_analyzer(folder, format="auto", lazy=True) + + assert isinstance(sorting_analyzer_lazy.sorting.to_spike_vector(), np.memmap) + + template_ext = sorting_analyzer_lazy.get_extension("templates") + template_data = template_ext.data + for key, value in template_data.items(): + if isinstance(value, np.ndarray): + assert isinstance(value, np.memmap) + spike_amplitudes_ext = sorting_analyzer_lazy.get_extension("spike_amplitudes") + spike_amplitudes_data = spike_amplitudes_ext.data + for key, value in spike_amplitudes_data.items(): + if isinstance(value, np.ndarray): + assert isinstance(value, np.memmap) + + +def test_load_in_lazy_mode_zarr(tmp_path, dataset): + import zarr + from spikeinterface.core.zarrextractors import ZarrSpikeVector + + recording, sorting = dataset + + folder = tmp_path / "test_SortingAnalyzer_zarr_folder.zarr" + if folder.exists(): + shutil.rmtree(folder) + + sorting_analyzer = create_sorting_analyzer( + sorting, recording, format="zarr", folder=folder, sparse=False, sparsity=None + ) + + sorting_analyzer.compute(["random_spikes", "templates", "spike_amplitudes"]) + # load in lazy mode and check that spikevector is ZarrSpikeVector andextension data are zarr arrays + sorting_analyzer_lazy = load_sorting_analyzer(folder, format="auto", lazy=True) + + assert isinstance(sorting_analyzer_lazy.sorting.to_spike_vector(), ZarrSpikeVector) + + template_ext = sorting_analyzer_lazy.get_extension("templates") + template_data = template_ext.data + for key, value in template_data.items(): + if isinstance(value, np.ndarray): + assert isinstance(value, zarr.Array) + spike_amplitudes_ext = sorting_analyzer_lazy.get_extension("spike_amplitudes") + spike_amplitudes_data = spike_amplitudes_ext.data + for key, value in spike_amplitudes_data.items(): + if isinstance(value, np.ndarray): + assert isinstance(value, zarr.Array) + + def _check_sorting_analyzers(sorting_analyzer, original_sorting, cache_folder): register_result_extension(DummyAnalyzerExtension) diff --git a/src/spikeinterface/core/zarrextractors.py b/src/spikeinterface/core/zarrextractors.py index bbc797c693..a832129fe1 100644 --- a/src/spikeinterface/core/zarrextractors.py +++ b/src/spikeinterface/core/zarrextractors.py @@ -241,6 +241,112 @@ def get_traces( return traces +class _ZarrSegmentIndex: + """Lazy segment_index array derived from segment_slices stored in zarr.""" + + def __init__(self, segment_slices: np.ndarray, n: int): + self._segment_slices = segment_slices + self._n = n + + def __len__(self) -> int: + return self._n + + def __array__(self, dtype=None): + arr = np.empty(self._n, dtype="int64") + for seg_idx, (s0, s1) in enumerate(self._segment_slices): + arr[s0:s1] = seg_idx + return arr if dtype is None else arr.astype(dtype) + + def __getitem__(self, key): + return np.asarray(self)[key] + + def __eq__(self, other): + return np.asarray(self) == other + + +class ZarrSpikeVector: + """ + Virtual structured spike vector backed by zarr arrays. + + Mimics a memmap-backed numpy structured array with fields + (sample_index, unit_index, segment_index) without loading any data + at construction time. Data is read from zarr lazily: + + * Field access (``spikes["sample_index"]``) returns the zarr array + (or a lazy segment-index object). + * Slice access (``spikes[s0:s1]``) materialises only that slice. + * ``np.asarray(spikes)`` materialises the full array. + + The zarr arrays are assumed to be stored in sorted order + (segment_index ASC, sample_index ASC, unit_index ASC), which is the + ordering guaranteed by :func:`add_sorting_to_zarr_group`. + """ + + def __init__(self, spikes_group, segment_slices: np.ndarray): + self._sample_index = spikes_group["sample_index"] + self._unit_index = spikes_group["unit_index"] + self._segment_slices = np.asarray(segment_slices, dtype="int64") + self._n = len(self._sample_index) + self.dtype = np.dtype(minimum_spike_dtype) + + @property + def size(self) -> int: + return self._n + + def __len__(self) -> int: + return self._n + + def __getitem__(self, key): + if isinstance(key, str): + if key == "sample_index": + return self._sample_index + elif key == "unit_index": + return self._unit_index + elif key == "segment_index": + return _ZarrSegmentIndex(self._segment_slices, self._n) + else: + raise KeyError(f"ZarrSpikeVector has no field {key!r}") + + if isinstance(key, (int, np.integer)): + idx = int(key) + if idx < 0: + idx += self._n + result = np.empty(1, dtype=self.dtype) + result["sample_index"][0] = self._sample_index[idx] + result["unit_index"][0] = self._unit_index[idx] + result["segment_index"][0] = int(np.searchsorted(self._segment_slices[:, 0], idx, side="right")) - 1 + return result[0] + + if isinstance(key, slice): + start, stop, step = key.indices(self._n) + n = len(range(start, stop, step)) + result = np.empty(n, dtype=self.dtype) + result["sample_index"] = self._sample_index[start:stop:step] + result["unit_index"] = self._unit_index[start:stop:step] + if step == 1: + seg_index = np.empty(n, dtype="int64") + for seg_idx, (s0, s1) in enumerate(self._segment_slices): + lo = max(start, int(s0)) - start + hi = min(stop, int(s1)) - start + if hi > lo: + seg_index[lo:hi] = seg_idx + result["segment_index"] = seg_index + else: + result["segment_index"] = _ZarrSegmentIndex(self._segment_slices, self._n)[start:stop:step] + return result + + # fallback for fancy/boolean indexing: materialise then index + return np.asarray(self)[key] + + def __array__(self, dtype=None): + arr = np.empty(self._n, dtype=self.dtype) + arr["sample_index"] = self._sample_index[:] + arr["unit_index"] = self._unit_index[:] + for seg_idx, (s0, s1) in enumerate(self._segment_slices): + arr["segment_index"][s0:s1] = seg_idx + return arr if dtype is None else arr.astype(dtype) + + class ZarrSortingExtractor(BaseSorting): """ SortingExtractor for a zarr format @@ -257,13 +363,23 @@ class ZarrSortingExtractor(BaseSorting): Storage options for zarr `store`. E.g., if "s3://" or "gcs://" they can provide authentication methods, etc. zarr_group : str or None, default: None Optional zarr group path to load the sorting from. This can be used when the sorting is not stored at the root, but in sub group. + lazy_spike_vector : bool, default: False + If True, the spike vector is loaded lazily. This can be useful for large sortings with many spikes. + If False, the spike vector is loaded in memory. Default: False + Returns ------- sorting : ZarrSortingExtractor The sorting Extractor """ - def __init__(self, folder_path: Path | str, storage_options: dict | None = None, zarr_group: str | None = None): + def __init__( + self, + folder_path: Path | str, + storage_options: dict | None = None, + zarr_group: str | None = None, + lazy_spike_vector: bool = False, + ): folder_path, folder_path_kwarg = resolve_zarr_path(folder_path) @@ -289,13 +405,21 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None, BaseSorting.__init__(self, sampling_frequency, unit_ids) - spikes = np.zeros(len(spikes_group["sample_index"]), dtype=minimum_spike_dtype) - spikes["sample_index"] = spikes_group["sample_index"][:] - spikes["unit_index"] = spikes_group["unit_index"][:] - for i, (start, end) in enumerate(segment_slices_list): - spikes["segment_index"][start:end] = i - spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))] + if lazy_spike_vector: + spikes = ZarrSpikeVector(spikes_group, segment_slices_list) + else: + # Materialize the spike vector in memory and sort it by (segment_index, sample_index, unit_index) + spikes = np.zeros(len(spikes_group["sample_index"]), dtype=minimum_spike_dtype) + spikes["sample_index"] = spikes_group["sample_index"][:] + spikes["unit_index"] = spikes_group["unit_index"][:] + for i, (start, end) in enumerate(segment_slices_list): + spikes["segment_index"][start:end] = i + spikes = spikes[np.lexsort((spikes["unit_index"], spikes["sample_index"], spikes["segment_index"]))] + self._cached_spike_vector = spikes + # pre-populate segment slices so _get_spike_vector_segment_slices() never + # needs to materialise the full segment_index array + self._cached_spike_vector_segment_slices = np.asarray(segment_slices_list, dtype="int64") for segment_index in range(num_segments): soring_segment = SpikeVectorSortingSegment(spikes, segment_index, unit_ids) @@ -313,7 +437,12 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None, if annotations is not None: self.annotate(**annotations) - self._kwargs = {"folder_path": folder_path_kwarg, "storage_options": storage_options, "zarr_group": zarr_group} + self._kwargs = { + "folder_path": folder_path_kwarg, + "storage_options": storage_options, + "zarr_group": zarr_group, + "lazy_spike_vector": lazy_spike_vector, + } @staticmethod def write_sorting(sorting: BaseSorting, folder_path: str | Path, storage_options: dict | None = None, **kwargs):