Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
219 changes: 217 additions & 2 deletions src/spikeinterface/preprocessing/resample.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,28 @@ class ResampleRecording(BasePreprocessor):
The recording extractor to be re-referenced
resample_rate : int
The resampling frequency
gap_tolerance_ms : float | None, default: None
Maximum acceptable gap size in milliseconds for automatic segmentation.

**Default behavior (None)**: If timestamp gaps are detected in the parent
recording's time vector, an error is raised with a detailed gap report.
This ensures users are aware of data discontinuities rather than silently
producing incorrect results.

**Opt-in segmentation**: Provide a value to automatically handle gaps via
section-wise resampling. Gaps larger than this threshold trigger section
splitting; gaps smaller than the threshold are ignored (data treated as
continuous). Within each contiguous section, resampling proceeds correctly.

In all cases, deviations smaller than 1.5 sample periods are never treated
as gaps, since sub-sample jitter and floating-point noise in time vectors
cannot represent dropped samples.

Examples:
- None (default): Error on any detected gaps
- 0.0: Strict mode — split on any gap >= 1.5 sample periods
- 1.0: Tolerate gaps up to 1 ms, split on larger gaps
- 100.0: Only major pauses (>100 ms) create sections
margin_ms : float, default: 100.0
Margin in ms for computations, will be used to decrease edge effects.
dtype : dtype or None, default: None
Expand All @@ -44,6 +66,7 @@ def __init__(
self,
recording,
resample_rate,
gap_tolerance_ms=None,
margin_ms=100.0,
dtype=None,
skip_checks=False,
Expand Down Expand Up @@ -73,12 +96,14 @@ def __init__(
recording.get_sampling_frequency(),
margin,
dtype,
gap_tolerance_ms,
)
)

self._kwargs = dict(
recording=recording,
resample_rate=resample_rate,
gap_tolerance_ms=gap_tolerance_ms,
margin_ms=margin_ms,
dtype=dtype,
skip_checks=skip_checks,
Expand All @@ -93,29 +118,120 @@ def __init__(
parent_rate,
margin,
dtype,
gap_tolerance_ms=None,
):
self._resample_rate = resample_rate
self._parent_segment = parent_recording_segment
self._parent_rate = parent_rate
self._margin = margin
self._dtype = dtype
self._has_gaps = False

# Compute time_vector or t_start, following the pattern from DecimateRecordingSegment.
# Do not use BasePreprocessorSegment because we have to reset the sampling rate!
if parent_recording_segment.time_vector is not None:
parent_tv = np.asarray(parent_recording_segment.time_vector)

# Detect gaps in the parent time vector.
# A true gap means at least one dropped sample, so dt >= 2 * expected_dt.
# Use 1.5 * expected_dt as the minimum threshold to avoid false positives
# from floating-point jitter while catching any real dropped samples.
expected_dt = 1.0 / parent_rate
min_gap_threshold = 1.5 * expected_dt
if gap_tolerance_ms is not None:
detection_threshold = max(min_gap_threshold, gap_tolerance_ms / 1000.0)
else:
detection_threshold = min_gap_threshold

diffs = np.diff(parent_tv)
gap_indices = np.flatnonzero(diffs > detection_threshold)

if len(gap_indices) > 0 and gap_tolerance_ms is None:
gap_sizes_ms = diffs[gap_indices] * 1000
gap_positions_s = parent_tv[gap_indices]
raise ValueError(
f"Detected {len(gap_indices)} timestamp gap(s) in the parent "
f"recording's time vector.\n"
f" Gap sizes (ms): {gap_sizes_ms}\n"
f" Gap positions (seconds): {gap_positions_s}\n"
f" Gap positions (parent sample indices): {gap_indices}\n"
f"To handle gaps automatically via section-wise resampling, "
f"pass gap_tolerance_ms=<threshold>. Gaps larger than the "
f"threshold will trigger section splitting; smaller gaps are "
f"treated as continuous."
)

# Build section boundaries: contiguous runs of samples between gaps.
# We call these "sections" (not "segments") to avoid confusion with
# the Segment concept in SpikeInterface/neo.
if len(gap_indices) == 0:
sec_boundaries_parent = np.array([[0, len(parent_tv)]], dtype=np.int64)
else:
self._has_gaps = True
starts = np.concatenate([[0], gap_indices + 1]).astype(np.int64)
ends = np.concatenate([gap_indices + 1, [len(parent_tv)]]).astype(np.int64)
sec_boundaries_parent = np.column_stack([starts, ends])

# Compute per-section output sample counts and cumulative boundaries.
K = len(sec_boundaries_parent)
sec_n_out = np.array(
[
int((sec_boundaries_parent[k, 1] - sec_boundaries_parent[k, 0]) / parent_rate * resample_rate)
for k in range(K)
],
dtype=np.int64,
)
sec_cumstart = np.zeros(K, dtype=np.int64)
sec_cumstart[1:] = np.cumsum(sec_n_out[:-1])
sec_boundaries_output = np.column_stack([sec_cumstart, sec_cumstart + sec_n_out])

self._sec_boundaries_parent = sec_boundaries_parent
self._sec_boundaries_output = sec_boundaries_output
self._sec_n_out = sec_n_out

# Compute time_vector
n_out = int(len(parent_tv) / parent_rate * resample_rate)

if parent_rate % resample_rate == 0:
q_int = int(parent_rate / resample_rate)
time_vector = parent_tv[::q_int][:n_out]
else:
if not self._has_gaps:
time_vector = parent_tv[::q_int][:n_out]
else:
# Section-wise slicing to keep time_vector consistent
# with _sec_boundaries_output
tv_pieces = []
for k in range(K):
p_start, p_end = sec_boundaries_parent[k]
n_out_k = sec_n_out[k]
if n_out_k == 0:
continue
tv_pieces.append(parent_tv[p_start:p_end:q_int][:n_out_k])
time_vector = np.concatenate(tv_pieces)
elif not self._has_gaps:
# Non-integer ratio, no gaps: existing fast path
warnings.warn(
"Resampling with a non-integer ratio requires interpolating the time_vector. "
"An integer ratio (parent_rate / resample_rate) is more performant."
)
parent_indices = np.linspace(0, len(parent_tv) - 1, n_out)
time_vector = np.interp(parent_indices, np.arange(len(parent_tv)), parent_tv)
else:
# Non-integer ratio with gaps: per-section interpolation
warnings.warn(
"Resampling with a non-integer ratio requires interpolating the time_vector. "
"An integer ratio (parent_rate / resample_rate) is more performant."
)
tv_pieces = []
for k in range(K):
p_start, p_end = sec_boundaries_parent[k]
n_out_k = sec_n_out[k]
if n_out_k == 0:
continue
sec_parent_tv = parent_tv[p_start:p_end]
sec_len = p_end - p_start
sec_indices = np.linspace(0, sec_len - 1, n_out_k)
tv_pieces.append(np.interp(sec_indices, np.arange(sec_len), sec_parent_tv))
time_vector = np.concatenate(tv_pieces)

BaseRecordingSegment.__init__(self, sampling_frequency=None, t_start=None, time_vector=time_vector)
else:
Expand All @@ -129,6 +245,10 @@ def get_num_samples(self):
return int(self._parent_segment.get_num_samples() / self._parent_rate * self._resample_rate)

def get_traces(self, start_frame, end_frame, channel_indices):
if self._has_gaps:
return self._get_traces_gapped(start_frame, end_frame, channel_indices)

# Original code path: no gaps (or no time_vector)
# get parent traces with margin
parent_start_frame, parent_end_frame = [
int((frame / self._resample_rate) * self._parent_rate) for frame in [start_frame, end_frame]
Expand Down Expand Up @@ -169,6 +289,101 @@ def get_traces(self, start_frame, end_frame, channel_indices):
resampled_traces = resampled_traces[left_margin_rs : num - right_margin_rs]
return resampled_traces.astype(self._dtype)

def _get_traces_gapped(self, start_frame, end_frame, channel_indices):
"""Resample traces section-by-section, avoiding FFT processing across gaps."""
from scipy import signal

if start_frame == end_frame:
# Determine n_channels from parent
n_channels = (
len(channel_indices) if channel_indices is not None else self._parent_segment.get_traces(0, 1).shape[1]
)
return np.empty((0, n_channels), dtype=self._dtype)

# Find which sections overlap [start_frame, end_frame) in output space.
# _sec_boundaries_output[k] = [out_start_k, out_end_k)
sec_ends = self._sec_boundaries_output[:, 1]
sec_starts = self._sec_boundaries_output[:, 0]
first_sec = int(np.searchsorted(sec_ends, start_frame, side="right"))
last_sec = int(np.searchsorted(sec_starts, end_frame, side="left")) - 1
first_sec = max(first_sec, 0)
last_sec = min(last_sec, len(self._sec_n_out) - 1)

is_integer_ratio = (self._parent_rate % self._resample_rate) == 0

pieces = []
for k in range(first_sec, last_sec + 1):
out_start_k = int(self._sec_boundaries_output[k, 0])
out_end_k = int(self._sec_boundaries_output[k, 1])
par_start_k = int(self._sec_boundaries_parent[k, 0])
par_end_k = int(self._sec_boundaries_parent[k, 1])
sec_n_parent = par_end_k - par_start_k
sec_n_output = int(self._sec_n_out[k])

if sec_n_output == 0:
continue

# Clip the output range to the requested [start_frame, end_frame)
local_out_start = max(start_frame, out_start_k) - out_start_k
local_out_end = min(end_frame, out_end_k) - out_start_k

if local_out_end <= local_out_start:
continue

# Map within-section output frames to within-section parent frames
local_par_start = int((local_out_start / self._resample_rate) * self._parent_rate)
local_par_end = int((local_out_end / self._resample_rate) * self._parent_rate)
local_par_start = max(0, min(local_par_start, sec_n_parent))
local_par_end = max(0, min(local_par_end, sec_n_parent))

# Apply margin within section boundaries only (do not cross gaps)
left_margin = min(self._margin, local_par_start)
right_margin = min(self._margin, sec_n_parent - local_par_end)

par_fetch_start = par_start_k + local_par_start - left_margin
par_fetch_end = par_start_k + local_par_end + right_margin

# Fetch parent traces for this section's sub-chunk
parent_traces = self._parent_segment.get_traces(par_fetch_start, par_fetch_end, channel_indices).astype(
np.float32
)

# Apply reflect padding if margin was truncated at section edge
pad_left = self._margin - left_margin
pad_right = self._margin - right_margin
if pad_left > 0 or pad_right > 0:
parent_traces = np.pad(parent_traces, [(pad_left, pad_right), (0, 0)], mode="reflect")
left_margin = self._margin
right_margin = self._margin

# Compute resampled margins
left_margin_rs = int((left_margin / self._parent_rate) * self._resample_rate)
right_margin_rs = int((right_margin / self._parent_rate) * self._resample_rate)

# Total output samples including margins
num = int(local_out_end - local_out_start) + left_margin_rs + right_margin_rs

# Resample this section
if is_integer_ratio:
q = int(self._parent_rate / self._resample_rate)
resampled = signal.decimate(parent_traces, q=q, axis=0)
if np.any(np.isnan(resampled)):
resampled = signal.resample(parent_traces, num, axis=0)
else:
resampled = signal.resample(parent_traces, num, axis=0)

# Trim margins
resampled = resampled[left_margin_rs : num - right_margin_rs]
pieces.append(resampled)

if len(pieces) == 0:
n_channels = (
len(channel_indices) if channel_indices is not None else self._parent_segment.get_traces(0, 1).shape[1]
)
return np.empty((0, n_channels), dtype=self._dtype)
result = np.concatenate(pieces, axis=0) if len(pieces) > 1 else pieces[0]
return result.astype(self._dtype)


resample = define_function_handling_dict_from_class(source_class=ResampleRecording, name="resample")

Expand Down
Loading
Loading