Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 38 additions & 18 deletions src/spikeinterface/core/sortinganalyzer.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,9 @@ def create_sorting_analyzer(
return sorting_analyzer


def load_sorting_analyzer(folder, load_extensions=True, format="auto", backend_options=None) -> "SortingAnalyzer":
def load_sorting_analyzer(
folder, load_extensions=True, format="auto", backend_options=None, lazy=False
) -> "SortingAnalyzer":
"""
Load a SortingAnalyzer object from disk.

Expand Down Expand Up @@ -245,7 +247,9 @@ def load_sorting_analyzer(folder, load_extensions=True, format="auto", backend_o
The loaded SortingAnalyzer

"""
return SortingAnalyzer.load(folder, load_extensions=load_extensions, format=format, backend_options=backend_options)
return SortingAnalyzer.load(
folder, load_extensions=load_extensions, format=format, backend_options=backend_options, lazy=lazy
)


class SortingAnalyzer:
Expand Down Expand Up @@ -407,7 +411,7 @@ def create(
return sorting_analyzer

@classmethod
def load(cls, folder, recording=None, load_extensions=True, format="auto", backend_options=None):
def load(cls, folder, recording=None, load_extensions=True, format="auto", backend_options=None, lazy=False):
"""
Load folder or zarr.
The recording can be given if the recording location has changed.
Expand All @@ -422,11 +426,11 @@ def load(cls, folder, recording=None, load_extensions=True, format="auto", backe

if format == "binary_folder":
sorting_analyzer = SortingAnalyzer.load_from_binary_folder(
folder, recording=recording, backend_options=backend_options
folder, recording=recording, backend_options=backend_options, lazy=lazy
)
elif format == "zarr":
sorting_analyzer = SortingAnalyzer.load_from_zarr(
folder, recording=recording, backend_options=backend_options
folder, recording=recording, backend_options=backend_options, lazy=lazy
)

if not is_path_remote(str(folder)):
Expand Down Expand Up @@ -532,15 +536,23 @@ def create_binary_folder(cls, folder, sorting, recording, sparsity, return_in_uV
return cls.load_from_binary_folder(folder, recording=recording, backend_options=backend_options)

@classmethod
def load_from_binary_folder(cls, folder, recording=None, backend_options=None):
def load_from_binary_folder(cls, folder, recording=None, backend_options=None, lazy=False):
from .loading import load

folder = Path(folder)
assert folder.is_dir(), f"This folder does not exists {folder}"

# load internal sorting copy in memory
if lazy:
numpy_folder_kwargs = dict(mmap_mode="r")
copy_spike_vector = False
else:
numpy_folder_kwargs = dict()
copy_spike_vector = True
sorting = NumpySorting.from_sorting(
NumpyFolderSorting(folder / "sorting"), with_metadata=True, copy_spike_vector=True
NumpyFolderSorting(folder / "sorting", **numpy_folder_kwargs),
with_metadata=True,
copy_spike_vector=copy_spike_vector,
)

# Try to load the recording if not provided
Expand Down Expand Up @@ -698,7 +710,7 @@ def create_zarr(cls, folder, sorting, recording, sparsity, return_in_uV, rec_att
return cls.load_from_zarr(folder, recording=recording, backend_options=backend_options)

@classmethod
def load_from_zarr(cls, folder, recording=None, backend_options=None):
def load_from_zarr(cls, folder, recording=None, backend_options=None, lazy=False):
import zarr
from .loading import load

Expand All @@ -722,6 +734,8 @@ def load_from_zarr(cls, folder, recording=None, backend_options=None):
)

# load internal sorting in memory
if lazy:
copy_spike_vector = False
sorting = NumpySorting.from_sorting(
ZarrSortingExtractor(folder, zarr_group="sorting", storage_options=storage_options),
with_metadata=True,
Expand Down Expand Up @@ -1894,7 +1908,7 @@ def get_saved_extension_names(self):

return saved_extension_names

def get_extension(self, extension_name: str):
def get_extension(self, extension_name: str, lazy: bool = False):
"""
Get a AnalyzerExtension.
If not loaded then load is automatic.
Expand All @@ -1906,20 +1920,22 @@ def get_extension(self, extension_name: str):
return self.extensions[extension_name]

elif self.format != "memory" and self.has_extension(extension_name):
self.load_extension(extension_name)
self.load_extension(extension_name, lazy=lazy)
return self.extensions[extension_name]

else:
return None

def load_extension(self, extension_name: str):
def load_extension(self, extension_name: str, lazy: bool = False):
"""
Load an extension from a folder or zarr into the `ResultSorting.extensions` dict.

Parameters
----------
extension_name : str
The extension name.
lazy : bool, default: False
If True, array data are not loaded in memory, but kept as memmap/zarr arrays

Returns
-------
Expand All @@ -1936,7 +1952,7 @@ def load_extension(self, extension_name: str):
if extension_class is None:
return None

extension_instance = extension_class.load(self)
extension_instance = extension_class.load(self, lazy=lazy)

self.extensions[extension_name] = extension_instance

Expand Down Expand Up @@ -2414,20 +2430,20 @@ def _get_zarr_extension_group(self, mode="r+"):
return extension_group

@classmethod
def load(cls, sorting_analyzer):
def load(cls, sorting_analyzer, lazy=False):
ext = cls(sorting_analyzer)
ext.load_params()
ext.load_run_info()
if ext.run_info is not None:
if ext.run_info["run_completed"]:
ext.load_data()
ext.load_data(lazy=lazy)
if cls.need_backward_compatibility_on_load:
ext._handle_backward_compatibility_on_load()
if len(ext.data) > 0:
return ext
else:
# this is for back-compatibility of old analyzers
ext.load_data()
ext.load_data(lazy=lazy)
if cls.need_backward_compatibility_on_load:
ext._handle_backward_compatibility_on_load()
if len(ext.data) > 0:
Expand Down Expand Up @@ -2527,7 +2543,7 @@ def load_params(self):

self.params = params

def load_data(self):
def load_data(self, lazy=False):
ext_data = None
if self.format == "binary_folder":
extension_folder = self._get_binary_extension_folder()
Expand All @@ -2550,7 +2566,8 @@ def load_data(self):
# and have a link to the old buffer on windows then it fails
# ext_data = np.load(ext_data_file, mmap_mode="r")
# so we go back to full loading
ext_data = np.load(ext_data_file)
kwargs = dict(mmap_mode="r") if lazy else dict()
ext_data = np.load(ext_data_file, **kwargs)
elif ext_data_file.suffix == ".csv":
import pandas as pd

Expand Down Expand Up @@ -2587,7 +2604,10 @@ def load_data(self):
ext_data = ext_data_[0]
else:
# this load in memmory
ext_data = np.array(ext_data_)
if lazy:
ext_data = ext_data_
else:
ext_data = np.array(ext_data_)
self.set_data(ext_data_name, ext_data)

if len(self.data) == 0:
Expand Down
6 changes: 3 additions & 3 deletions src/spikeinterface/core/sortingfolder.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ class NumpyFolderSorting(BaseSorting):
mode = "folder"
name = "NumpyFolder"

def __init__(self, folder_path):
def __init__(self, folder_path, mmap_mode=None):
folder_path = Path(folder_path)

with open(folder_path / "numpysorting_info.json", "r") as f:
Expand All @@ -36,7 +36,7 @@ def __init__(self, folder_path):

BaseSorting.__init__(self, sampling_frequency, unit_ids)

self.spikes = np.load(folder_path / "spikes.npy")
self.spikes = np.load(folder_path / "spikes.npy", mmap_mode=mmap_mode)

for segment_index in range(num_segments):
self.add_sorting_segment(SpikeVectorSortingSegment(self.spikes, segment_index, unit_ids))
Expand All @@ -47,7 +47,7 @@ def __init__(self, folder_path):
folder_metadata = folder_path
self.load_metadata_from_folder(folder_metadata)

self._kwargs = dict(folder_path=str(folder_path.absolute()))
self._kwargs = dict(folder_path=str(folder_path.absolute()), mmap_mode=mmap_mode)

Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

When are the _kwargs large? Do they need to be lazy?

Copy link
Copy Markdown
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

They are not. This is the kwargs to instantiate the Sorting object!


@staticmethod
def write_sorting(sorting, save_path):
Expand Down
2 changes: 2 additions & 0 deletions src/spikeinterface/core/zarrextractors.py
Original file line number Diff line number Diff line change
Expand Up @@ -289,6 +289,8 @@ def __init__(self, folder_path: Path | str, storage_options: dict | None = None,

BaseSorting.__init__(self, sampling_frequency, unit_ids)

# TODO: make a virtual memmap view of the spike vector or override to_spike_vector to behave like
# a memmap
spikes = np.zeros(len(spikes_group["sample_index"]), dtype=minimum_spike_dtype)
spikes["sample_index"] = spikes_group["sample_index"][:]
spikes["unit_index"] = spikes_group["unit_index"][:]
Expand Down
Loading