diff --git a/src/spikeinterface/core/basesorting.py b/src/spikeinterface/core/basesorting.py index cb68f3d455..4aa266f956 100644 --- a/src/spikeinterface/core/basesorting.py +++ b/src/spikeinterface/core/basesorting.py @@ -180,6 +180,7 @@ def get_unit_spike_train( segment_index=segment_index, start_time=start_time, end_time=end_time, + use_cache=use_cache, ) segment_index = self._check_segment_index(segment_index) @@ -212,6 +213,7 @@ def get_unit_spike_train_in_seconds( segment_index: int | None = None, start_time: float | None = None, end_time: float | None = None, + use_cache: bool = True, ) -> np.ndarray: """ Get spike train for a unit in seconds. @@ -236,6 +238,10 @@ def get_unit_spike_train_in_seconds( The start time in seconds for spike train extraction end_time : float or None, default: None The end time in seconds for spike train extraction + return_times : bool, default: False + If True, returns spike times in seconds instead of frames + use_cache : bool, default: True + If True, then precompute (or use) the to_reordered_spike_vector using Returns ------- @@ -258,7 +264,7 @@ def get_unit_spike_train_in_seconds( start_frame=start_frame, end_frame=end_frame, return_times=False, - use_cache=True, + use_cache=use_cache, ) spike_times = self.sample_index_to_time(spike_frames, segment_index=segment_index) @@ -288,13 +294,172 @@ def get_unit_spike_train_in_seconds( start_frame=start_frame, end_frame=end_frame, return_times=False, - use_cache=True, + use_cache=use_cache, ) t_start = segment._t_start if segment._t_start is not None else 0 spike_times = spike_frames / self.get_sampling_frequency() return t_start + spike_times + def get_unit_spike_trains( + self, + unit_ids: np.ndarray | list, + segment_index: int | None = None, + start_frame: int | None = None, + end_frame: int | None = None, + return_times: bool = False, + use_cache: bool = True, + ) -> dict[int | str, np.ndarray]: + """Return spike trains for multiple units. 
+ + Parameters + ---------- + unit_ids : np.ndarray | list + Unit ids to retrieve spike trains for + segment_index : int or None, default: None + The segment index to retrieve spike train from. + For multi-segment objects, it is required + start_frame : int or None, default: None + The start frame for spike train extraction + end_frame : int or None, default: None + The end frame for spike train extraction + return_times : bool, default: False + If True, returns spike times in seconds instead of frames + use_cache : bool, default: True + If True, then precompute (or use) the to_reordered_spike_vector using + + Returns + ------- + dict[int | str, np.ndarray] + A dictionary where keys are unit ids and values are spike trains (arrays of spike times or frames) + """ + if return_times: + start_time = ( + self.sample_index_to_time(start_frame, segment_index=segment_index) if start_frame is not None else None + ) + end_time = ( + self.sample_index_to_time(end_frame, segment_index=segment_index) if end_frame is not None else None + ) + + return self.get_unit_spike_trains_in_seconds( + unit_ids=unit_ids, + segment_index=segment_index, + start_time=start_time, + end_time=end_time, + use_cache=use_cache, + ) + + segment_index = self._check_segment_index(segment_index) + segment = self.segments[segment_index] + if use_cache: + # TODO: speed things up + ordered_spike_vector, slices = self.to_reordered_spike_vector( + lexsort=("sample_index", "segment_index", "unit_index"), + return_order=False, + return_slices=True, + ) + unit_indices = self.ids_to_indices(unit_ids) + spike_trains = {} + for unit_index, unit_id in zip(unit_indices, unit_ids): + sl0, sl1 = slices[unit_index, segment_index, :] + spikes = ordered_spike_vector[sl0:sl1] + spike_frames = spikes["sample_index"] + if start_frame is not None: + start = np.searchsorted(spike_frames, start_frame) + spike_frames = spike_frames[start:] + if end_frame is not None: + end = np.searchsorted(spike_frames, end_frame) + 
spike_frames = spike_frames[:end] + spike_trains[unit_id] = spike_frames + else: + spike_trains = segment.get_unit_spike_trains( + unit_ids=unit_ids, start_frame=start_frame, end_frame=end_frame + ) + return spike_trains + + def get_unit_spike_trains_in_seconds( + self, + unit_ids: np.ndarray | list, + segment_index: int | None = None, + start_time: float | None = None, + end_time: float | None = None, + return_times: bool = False, + use_cache: bool = True, + ) -> dict[int | str, np.ndarray]: + """Return spike trains for multiple units in seconds + + Parameters + ---------- + unit_ids : np.ndarray | list + Unit ids to retrieve spike trains for + segment_index : int or None, default: None + The segment index to retrieve spike train from. + For multi-segment objects, it is required + start_time : float or None, default: None + The start time in seconds for spike train extraction + end_time : float or None, default: None + The end time in seconds for spike train extraction + return_times : bool, default: False + If True, returns spike times in seconds instead of frames + use_cache : bool, default: True + If True, then precompute (or use) the to_reordered_spike_vector using + + Returns + ------- + dict[int | str, np.ndarray] + A dictionary where keys are unit ids and values are spike trains (arrays of spike times in seconds) + """ + segment_index = self._check_segment_index(segment_index) + segment = self.segments[segment_index] + + # If sorting has a registered recording, get the frames and get the times from the recording + # Note that this take into account the segment start time of the recording + spike_times = {} + if self.has_recording(): + # Get all the spike times and then slice them + start_frame = None + end_frame = None + spike_train_frames = self.get_unit_spike_trains( + unit_ids=unit_ids, + segment_index=segment_index, + start_frame=start_frame, + end_frame=end_frame, + return_times=False, + use_cache=use_cache, + ) + + for unit_id in unit_ids: + 
spike_frames = self.sample_index_to_time(spike_train_frames[unit_id], segment_index=segment_index)
+
+                # Filter to return only the spikes within the specified time range
+                if start_time is not None:
+                    spike_frames = spike_frames[spike_frames >= start_time]
+                if end_time is not None:
+                    spike_frames = spike_frames[spike_frames <= end_time]
+
+                spike_times[unit_id] = spike_frames
+
+            return spike_times
+
+        # If no recording attached, fall back to frame-based conversion
+        # Get spike train in frames and convert to times using traditional method
+        start_frame = self.time_to_sample_index(start_time, segment_index=segment_index) if start_time is not None else None
+        end_frame = self.time_to_sample_index(end_time, segment_index=segment_index) if end_time is not None else None
+
+        spike_frames = self.get_unit_spike_trains(
+            unit_ids=unit_ids,
+            segment_index=segment_index,
+            start_frame=start_frame,
+            end_frame=end_frame,
+            return_times=False,
+            use_cache=use_cache,
+        )
+        for unit_id in unit_ids:
+            spike_frames_unit = spike_frames[unit_id]
+            t_start = segment._t_start if segment._t_start is not None else 0
+            spike_times[unit_id] = spike_frames_unit / self.get_sampling_frequency() + t_start
+        return spike_times
+
     def register_recording(self, recording, check_spike_frames: bool = True):
         """
         Register a recording to the sorting. 
If the sorting and recording both contain @@ -978,7 +1143,7 @@ def to_reordered_spike_vector( s1 = seg_slices[segment_index + 1] slices[unit_index, segment_index, :] = [u0 + s0, u0 + s1] - elif ("sample_index", "unit_index", "segment_index"): + elif lexsort == ("sample_index", "unit_index", "segment_index"): slices = np.zeros((num_segments, num_units, 2), dtype=np.int64) seg_slices = np.searchsorted(ordered_spikes["segment_index"], np.arange(num_segments + 1), side="left") for segment_index in range(self.get_num_segments()): @@ -1083,7 +1248,7 @@ def __init__(self, t_start=None): def get_unit_spike_train( self, - unit_id, + unit_id: int | str, start_frame: int | None = None, end_frame: int | None = None, ) -> np.ndarray: @@ -1091,18 +1256,51 @@ def get_unit_spike_train( Parameters ---------- - unit_id + unit_id : int | str + The unit id for which to get the spike train. start_frame : int, default: None + The start frame for the spike train. If None, it is set to the beginning of the segment. end_frame : int, default: None + The end frame for the spike train. If None, it is set to the end of the segment. + Returns ------- np.ndarray - + The spike train for the given unit id and time interval. """ # must be implemented in subclass raise NotImplementedError + def get_unit_spike_trains( + self, + unit_ids: np.ndarray | list, + start_frame: int | None = None, + end_frame: int | None = None, + ) -> dict[int | str, np.ndarray]: + """Get the spike trains for several units. + Can be implemented in subclass for performance but the default implementation is to call + get_unit_spike_train for each unit_id. + + Parameters + ---------- + unit_ids : numpy.array or list + The unit ids for which to get the spike trains. + start_frame : int, default: None + The start frame for the spike trains. If None, it is set to the beginning of the segment. + end_frame : int, default: None + The end frame for the spike trains. If None, it is set to the end of the segment. 
+ + Returns + ------- + dict[int | str, np.ndarray] + A dictionary where keys are unit_ids and values are the corresponding spike trains. + """ + spike_trains = {} + for unit_id in unit_ids: + spike_trains[unit_id] = self.get_unit_spike_train(unit_id, start_frame=start_frame, end_frame=end_frame) + return spike_trains + class SpikeVectorSortingSegment(BaseSortingSegment): """ diff --git a/src/spikeinterface/core/tests/test_basesorting.py b/src/spikeinterface/core/tests/test_basesorting.py index 6c06b212b8..3d56d3e4e5 100644 --- a/src/spikeinterface/core/tests/test_basesorting.py +++ b/src/spikeinterface/core/tests/test_basesorting.py @@ -310,6 +310,37 @@ def test_select_periods(): np.testing.assert_array_equal(sliced_sorting.to_spike_vector(), sliced_sorting_array.to_spike_vector()) +@pytest.mark.parametrize("use_cache", [False, True]) +def test_get_unit_spike_trains(use_cache): + sampling_frequency = 10_000.0 + duration = 1.0 + num_samples = int(sampling_frequency * duration) + num_units = 10 + sorting = generate_sorting(durations=[duration], sampling_frequency=sampling_frequency, num_units=num_units) + + all_spike_trains = sorting.get_unit_spike_trains(unit_ids=sorting.unit_ids, use_cache=use_cache) + assert isinstance(all_spike_trains, dict) + assert set(all_spike_trains.keys()) == set(sorting.unit_ids) + for unit_id in sorting.unit_ids: + spiketrain = sorting.get_unit_spike_train(segment_index=0, unit_id=unit_id, use_cache=use_cache) + assert np.array_equal(all_spike_trains[unit_id], spiketrain) + + # test with times + spike_trains_times = sorting.get_unit_spike_trains_in_seconds( + unit_ids=sorting.unit_ids, return_times=True, use_cache=use_cache + ) + assert isinstance(spike_trains_times, dict) + assert set(spike_trains_times.keys()) == set(sorting.unit_ids) + for unit_id in sorting.unit_ids: + spiketrain = sorting.get_unit_spike_train( + segment_index=0, unit_id=unit_id, use_cache=use_cache, return_times=True + ) + spiketrain_times = 
sorting.get_unit_spike_train_in_seconds(
+            segment_index=0, unit_id=unit_id, use_cache=use_cache
+        )
+        assert np.allclose(spiketrain_times, spiketrain)
+
+
 if __name__ == "__main__":
     import tempfile
diff --git a/src/spikeinterface/core/unitsselectionsorting.py b/src/spikeinterface/core/unitsselectionsorting.py
index 59356db976..d8d2d92afb 100644
--- a/src/spikeinterface/core/unitsselectionsorting.py
+++ b/src/spikeinterface/core/unitsselectionsorting.py
@@ -59,11 +59,15 @@ def _compute_and_cache_spike_vector(self) -> None:
             all_old_unit_ids=self._parent_sorting.unit_ids,
             all_new_unit_ids=self._unit_ids,
         )
-        # lexsort by segment_index, sample_index, unit_index
-        sort_indices = np.lexsort(
-            (spike_vector["unit_index"], spike_vector["sample_index"], spike_vector["segment_index"])
-        )
-        self._cached_spike_vector = spike_vector[sort_indices]
+        # lexsort by segment_index, sample_index, unit_index, only if needed
+        # (remapping can change the order of unit indices)
+        if len(self._renamed_unit_ids) > 1 and np.diff(self.ids_to_indices(self._renamed_unit_ids)).min() < 0:
+            sort_indices = np.lexsort(
+                (spike_vector["unit_index"], spike_vector["sample_index"], spike_vector["segment_index"])
+            )
+            self._cached_spike_vector = spike_vector[sort_indices]
+        else:
+            self._cached_spike_vector = spike_vector
 
 
 class UnitsSelectionSortingSegment(BaseSortingSegment):
@@ -81,3 +85,14 @@ def get_unit_spike_train(
         unit_id_parent = self._ids_conversion[unit_id]
         times = self._parent_segment.get_unit_spike_train(unit_id_parent, start_frame, end_frame)
         return times
+
+    def get_unit_spike_trains(
+        self,
+        unit_ids,
+        start_frame: int | None = None,
+        end_frame: int | None = None,
+    ) -> dict:
+        unit_ids_parent = [self._ids_conversion[unit_id] for unit_id in unit_ids]
+        parent_spike_trains = self._parent_segment.get_unit_spike_trains(unit_ids_parent, start_frame, end_frame)
+        # map keys back to this sorting's unit ids so they match the requested unit_ids
+        return {unit_id: parent_spike_trains[parent_id] for unit_id, parent_id in zip(unit_ids, unit_ids_parent)}