From 1e319199bdd1ab3d421b265ea42354c3a76afbf3 Mon Sep 17 00:00:00 2001 From: Graham Findlay Date: Wed, 8 Apr 2026 13:51:40 -0500 Subject: [PATCH 1/2] Fix resampling of gapped recordings. --- src/spikeinterface/preprocessing/resample.py | 215 ++++++++++++++- .../preprocessing/tests/test_resample.py | 257 +++++++++++++++++- 2 files changed, 468 insertions(+), 4 deletions(-) diff --git a/src/spikeinterface/preprocessing/resample.py b/src/spikeinterface/preprocessing/resample.py index 902bd6d176..0c574417fd 100644 --- a/src/spikeinterface/preprocessing/resample.py +++ b/src/spikeinterface/preprocessing/resample.py @@ -26,6 +26,28 @@ class ResampleRecording(BasePreprocessor): The recording extractor to be re-referenced resample_rate : int The resampling frequency + gap_tolerance_ms : float | None, default: None + Maximum acceptable gap size in milliseconds for automatic segmentation. + + **Default behavior (None)**: If timestamp gaps are detected in the parent + recording's time vector, an error is raised with a detailed gap report. + This ensures users are aware of data discontinuities rather than silently + producing incorrect results. + + **Opt-in segmentation**: Provide a value to automatically handle gaps via + section-wise resampling. Gaps larger than this threshold trigger section + splitting; gaps smaller than the threshold are ignored (data treated as + continuous). Within each contiguous section, resampling proceeds correctly. + + In all cases, deviations smaller than 1.5 sample periods are never treated + as gaps, since sub-sample jitter and floating-point noise in time vectors + cannot represent dropped samples. + + Examples: + - None (default): Error on any detected gaps + - 0.0: Strict mode — split on any gap >= 1.5 sample periods + - 1.0: Tolerate gaps up to 1 ms, split on larger gaps + - 100.0: Only major pauses (>100 ms) create sections margin_ms : float, default: 100.0 Margin in ms for computations, will be used to decrease edge effects. dtype : dtype or None, default: None @@ -44,6 +66,7 @@ def __init__( self, recording, resample_rate, + gap_tolerance_ms=None, margin_ms=100.0, dtype=None, skip_checks=False, @@ -73,12 +96,14 @@ def __init__( recording.get_sampling_frequency(), margin, dtype, + gap_tolerance_ms, ) ) self._kwargs = dict( recording=recording, resample_rate=resample_rate, + gap_tolerance_ms=gap_tolerance_ms, margin_ms=margin_ms, dtype=dtype, skip_checks=skip_checks, @@ -93,29 +118,120 @@ def __init__( parent_rate, margin, dtype, + gap_tolerance_ms=None, ): self._resample_rate = resample_rate self._parent_segment = parent_recording_segment self._parent_rate = parent_rate self._margin = margin self._dtype = dtype + self._has_gaps = False # Compute time_vector or t_start, following the pattern from DecimateRecordingSegment. # Do not use BasePreprocessorSegment because we have to reset the sampling rate! if parent_recording_segment.time_vector is not None: parent_tv = np.asarray(parent_recording_segment.time_vector) + + # Detect gaps in the parent time vector. + # A true gap means at least one dropped sample, so dt >= 2 * expected_dt. + # Use 1.5 * expected_dt as the minimum threshold to avoid false positives + # from floating-point jitter while catching any real dropped samples. + expected_dt = 1.0 / parent_rate + min_gap_threshold = 1.5 * expected_dt + if gap_tolerance_ms is not None: + detection_threshold = max(min_gap_threshold, gap_tolerance_ms / 1000.0) + else: + detection_threshold = min_gap_threshold + + diffs = np.diff(parent_tv) + gap_indices = np.flatnonzero(diffs > detection_threshold) + + if len(gap_indices) > 0 and gap_tolerance_ms is None: + gap_sizes_ms = diffs[gap_indices] * 1000 + gap_positions_s = parent_tv[gap_indices] + raise ValueError( + f"Detected {len(gap_indices)} timestamp gap(s) in the parent " + f"recording's time vector.\n" + f" Gap sizes (ms): {gap_sizes_ms}\n" + f" Gap positions (seconds): {gap_positions_s}\n" + f" Gap positions (parent sample indices): {gap_indices}\n" + f"To handle gaps automatically via section-wise resampling, " + f"pass gap_tolerance_ms=. Gaps larger than the " + f"threshold will trigger section splitting; smaller gaps are " + f"treated as continuous." + ) + + # Build section boundaries: contiguous runs of samples between gaps. + # We call these "sections" (not "segments") to avoid confusion with + # the Segment concept in SpikeInterface/neo. + if len(gap_indices) == 0: + sec_boundaries_parent = np.array([[0, len(parent_tv)]], dtype=np.int64) + else: + self._has_gaps = True + starts = np.concatenate([[0], gap_indices + 1]).astype(np.int64) + ends = np.concatenate([gap_indices + 1, [len(parent_tv)]]).astype(np.int64) + sec_boundaries_parent = np.column_stack([starts, ends]) + + # Compute per-section output sample counts and cumulative boundaries. + K = len(sec_boundaries_parent) + sec_n_out = np.array( + [ + int((sec_boundaries_parent[k, 1] - sec_boundaries_parent[k, 0]) / parent_rate * resample_rate) + for k in range(K) + ], + dtype=np.int64, + ) + sec_cumstart = np.zeros(K, dtype=np.int64) + sec_cumstart[1:] = np.cumsum(sec_n_out[:-1]) + sec_boundaries_output = np.column_stack([sec_cumstart, sec_cumstart + sec_n_out]) + + self._sec_boundaries_parent = sec_boundaries_parent + self._sec_boundaries_output = sec_boundaries_output + self._sec_n_out = sec_n_out + + # Compute time_vector n_out = int(len(parent_tv) / parent_rate * resample_rate) if parent_rate % resample_rate == 0: q_int = int(parent_rate / resample_rate) - time_vector = parent_tv[::q_int][:n_out] - else: + if not self._has_gaps: + time_vector = parent_tv[::q_int][:n_out] + else: + # Section-wise slicing to keep time_vector consistent + # with _sec_boundaries_output + tv_pieces = [] + for k in range(K): + p_start, p_end = sec_boundaries_parent[k] + n_out_k = sec_n_out[k] + if n_out_k == 0: + continue + tv_pieces.append(parent_tv[p_start:p_end:q_int][:n_out_k]) + time_vector = np.concatenate(tv_pieces) + elif not self._has_gaps: + # Non-integer ratio, no gaps: existing fast path warnings.warn( "Resampling with a non-integer ratio requires interpolating the time_vector. " "An integer ratio (parent_rate / resample_rate) is more performant." ) parent_indices = np.linspace(0, len(parent_tv) - 1, n_out) time_vector = np.interp(parent_indices, np.arange(len(parent_tv)), parent_tv) + else: + # Non-integer ratio with gaps: per-section interpolation + warnings.warn( + "Resampling with a non-integer ratio requires interpolating the time_vector. " + "An integer ratio (parent_rate / resample_rate) is more performant." + ) + tv_pieces = [] + for k in range(K): + p_start, p_end = sec_boundaries_parent[k] + n_out_k = sec_n_out[k] + if n_out_k == 0: + continue + sec_parent_tv = parent_tv[p_start:p_end] + sec_len = p_end - p_start + sec_indices = np.linspace(0, sec_len - 1, n_out_k) + tv_pieces.append(np.interp(sec_indices, np.arange(sec_len), sec_parent_tv)) + time_vector = np.concatenate(tv_pieces) BaseRecordingSegment.__init__(self, sampling_frequency=None, t_start=None, time_vector=time_vector) else: @@ -129,6 +245,10 @@ def get_num_samples(self): return int(self._parent_segment.get_num_samples() / self._parent_rate * self._resample_rate) def get_traces(self, start_frame, end_frame, channel_indices): + if self._has_gaps: + return self._get_traces_gapped(start_frame, end_frame, channel_indices) + + # Original code path: no gaps (or no time_vector) # get parent traces with margin parent_start_frame, parent_end_frame = [ int((frame / self._resample_rate) * self._parent_rate) for frame in [start_frame, end_frame] @@ -169,6 +289,97 @@ def get_traces(self, start_frame, end_frame, channel_indices): resampled_traces = resampled_traces[left_margin_rs : num - right_margin_rs] return resampled_traces.astype(self._dtype) + def _get_traces_gapped(self, start_frame, end_frame, channel_indices): + """Resample traces section-by-section, avoiding FFT processing across gaps.""" + from scipy import signal + + if start_frame == end_frame: + # Determine n_channels from parent + n_channels = len(channel_indices) if channel_indices is not None else self._parent_segment.get_traces(0, 1).shape[1] + return np.empty((0, n_channels), dtype=self._dtype) + + # Find which sections overlap [start_frame, end_frame) in output space. + # _sec_boundaries_output[k] = [out_start_k, out_end_k) + sec_ends = self._sec_boundaries_output[:, 1] + sec_starts = self._sec_boundaries_output[:, 0] + first_sec = int(np.searchsorted(sec_ends, start_frame, side="right")) + last_sec = int(np.searchsorted(sec_starts, end_frame, side="left")) - 1 + first_sec = max(first_sec, 0) + last_sec = min(last_sec, len(self._sec_n_out) - 1) + + is_integer_ratio = (self._parent_rate % self._resample_rate) == 0 + + pieces = [] + for k in range(first_sec, last_sec + 1): + out_start_k = int(self._sec_boundaries_output[k, 0]) + out_end_k = int(self._sec_boundaries_output[k, 1]) + par_start_k = int(self._sec_boundaries_parent[k, 0]) + par_end_k = int(self._sec_boundaries_parent[k, 1]) + sec_n_parent = par_end_k - par_start_k + sec_n_output = int(self._sec_n_out[k]) + + if sec_n_output == 0: + continue + + # Clip the output range to the requested [start_frame, end_frame) + local_out_start = max(start_frame, out_start_k) - out_start_k + local_out_end = min(end_frame, out_end_k) - out_start_k + + if local_out_end <= local_out_start: + continue + + # Map within-section output frames to within-section parent frames + local_par_start = int((local_out_start / self._resample_rate) * self._parent_rate) + local_par_end = int((local_out_end / self._resample_rate) * self._parent_rate) + local_par_start = max(0, min(local_par_start, sec_n_parent)) + local_par_end = max(0, min(local_par_end, sec_n_parent)) + + # Apply margin within section boundaries only (do not cross gaps) + left_margin = min(self._margin, local_par_start) + right_margin = min(self._margin, sec_n_parent - local_par_end) + + par_fetch_start = par_start_k + local_par_start - left_margin + par_fetch_end = par_start_k + local_par_end + right_margin + + # Fetch parent traces for this section's sub-chunk + parent_traces = self._parent_segment.get_traces( + par_fetch_start, par_fetch_end, channel_indices + ).astype(np.float32) + + # Apply reflect padding if margin was truncated at section edge + pad_left = self._margin - left_margin + pad_right = self._margin - right_margin + if pad_left > 0 or pad_right > 0: + parent_traces = np.pad(parent_traces, [(pad_left, pad_right), (0, 0)], mode="reflect") + left_margin = self._margin + right_margin = self._margin + + # Compute resampled margins + left_margin_rs = int((left_margin / self._parent_rate) * self._resample_rate) + right_margin_rs = int((right_margin / self._parent_rate) * self._resample_rate) + + # Total output samples including margins + num = int(local_out_end - local_out_start) + left_margin_rs + right_margin_rs + + # Resample this section + if is_integer_ratio: + q = int(self._parent_rate / self._resample_rate) + resampled = signal.decimate(parent_traces, q=q, axis=0) + if np.any(np.isnan(resampled)): + resampled = signal.resample(parent_traces, num, axis=0) + else: + resampled = signal.resample(parent_traces, num, axis=0) + + # Trim margins + resampled = resampled[left_margin_rs : num - right_margin_rs] + pieces.append(resampled) + + if len(pieces) == 0: + n_channels = len(channel_indices) if channel_indices is not None else self._parent_segment.get_traces(0, 1).shape[1] + return np.empty((0, n_channels), dtype=self._dtype) + result = np.concatenate(pieces, axis=0) if len(pieces) > 1 else pieces[0] + return result.astype(self._dtype) + resample = define_function_handling_dict_from_class(source_class=ResampleRecording, name="resample") diff --git a/src/spikeinterface/preprocessing/tests/test_resample.py b/src/spikeinterface/preprocessing/tests/test_resample.py index c53b7b42bd..bc2ec27bbe 100644 --- a/src/spikeinterface/preprocessing/tests/test_resample.py +++ b/src/spikeinterface/preprocessing/tests/test_resample.py @@ -3,6 +3,7 @@ import numpy as np +import pytest DEBUG = False # DEBUG = True @@ -73,6 +74,20 @@ def get_fft(traces, sampling_frequency): return xf, nyf +def _make_gapped_recording(sampling_frequency=30000, n_channels=2, sec1_duration=1.0, sec2_duration=1.0, gap_s=5.0): + """Helper: create a NumpyRecording with a time_vector gap between two sections.""" + n1 = int(sec1_duration * sampling_frequency) + n2 = int(sec2_duration * sampling_frequency) + n_total = n1 + n2 + traces = np.random.randn(n_total, n_channels).astype(np.float32) + rec = NumpyRecording(traces, sampling_frequency) + + tv = np.arange(n_total, dtype="float64") / sampling_frequency + tv[n1:] += gap_s + rec.set_times(tv) + return rec, n1, n2 + + def test_resample_freq_domain(): sampling_frequency = 3e4 duration = 10 @@ -240,7 +255,8 @@ def test_resample_does_not_mutate_parent(): def test_resample_preserves_time_vector_integer_ratio(): - """Resampling with integer ratio should slice the parent time_vector.""" + """Resampling with integer ratio should slice the parent time_vector, + preserving gaps when gap_tolerance_ms is provided.""" sampling_frequency = 30000 resample_rate = 500 n_samples = sampling_frequency * 2 @@ -254,7 +270,7 @@ def test_resample_preserves_time_vector_integer_ratio(): time_vector[midpoint:] += 5.0 parent_rec.set_times(time_vector) - resampled = resample(parent_rec, resample_rate) + resampled = resample(parent_rec, resample_rate, gap_tolerance_ms=1.0) assert resampled.has_time_vector() resampled_times = resampled.get_times() @@ -298,6 +314,236 @@ def test_resample_preserves_time_vector_non_integer_ratio(): assert np.isclose(resampled_times[0], 10.0, atol=1.0 / sampling_frequency) +def test_resample_errors_on_gaps_by_default(): + """With gap_tolerance_ms=None (default), a gapped time vector should raise ValueError.""" + rec, _, _ = _make_gapped_recording() + with pytest.raises(ValueError, match="timestamp gap"): + resample(rec, 500) + + +def test_resample_preserves_gaps_non_integer_ratio(): + """Non-integer ratio with gap_tolerance_ms should preserve the gap in the output time_vector.""" + sampling_frequency = 30000 + resample_rate = 700 # non-integer ratio + gap_s = 5.0 + rec, n1, n2 = _make_gapped_recording(sampling_frequency=sampling_frequency, gap_s=gap_s) + + import warnings as _warnings + + with _warnings.catch_warnings(): + _warnings.simplefilter("ignore") + resampled = resample(rec, resample_rate, gap_tolerance_ms=1.0) + + assert resampled.has_time_vector() + resampled_times = resampled.get_times() + assert len(resampled_times) == resampled.get_num_samples() + + # The gap should be preserved + diffs = np.diff(resampled_times) + normal_dt = 1.0 / resample_rate + gap_indices = np.where(diffs > normal_dt * 2)[0] + assert len(gap_indices) == 1, f"Expected 1 gap, found {len(gap_indices)}" + + # Gap size should be approximately gap_s (plus one normal dt) + assert np.isclose(diffs[gap_indices[0]], gap_s + normal_dt, atol=2 * normal_dt) + + # No timestamps should fall inside the gap + parent_tv = rec.get_times() + gap_start_t = parent_tv[n1 - 1] + gap_end_t = parent_tv[n1] + in_gap = (resampled_times > gap_start_t + normal_dt) & (resampled_times < gap_end_t - normal_dt) + assert not np.any(in_gap), "Output timestamps fall inside the gap" + + +def test_resample_traces_across_gap(): + """Section-wise resampling should match individually resampled sections. + + Build a gapped recording, resample it with gap_tolerance_ms, and verify + that each section's output matches what you'd get by resampling that + section alone (without the gap). This confirms that _get_traces_gapped + does not apply FFT processing across gap boundaries. + """ + sampling_frequency = 30000 + resample_rate = 700 # non-integer ratio + sec_duration = 2.0 + gap_s = 5.0 + + n1 = int(sec_duration * sampling_frequency) + n2 = int(sec_duration * sampling_frequency) + + # Build random traces (more realistic than a sinusoid) + rng = np.random.default_rng(42) + traces1 = rng.standard_normal((n1, 2)).astype(np.float32) + traces2 = rng.standard_normal((n2, 2)).astype(np.float32) + traces = np.concatenate([traces1, traces2], axis=0) + + t1 = np.arange(n1, dtype="float64") / sampling_frequency + t2 = np.arange(n2, dtype="float64") / sampling_frequency + sec_duration + gap_s + tv = np.concatenate([t1, t2]) + + rec = NumpyRecording(traces, sampling_frequency) + rec.set_times(tv) + + # Resample the gapped recording + import warnings as _warnings + + with _warnings.catch_warnings(): + _warnings.simplefilter("ignore") + resampled = resample(rec, resample_rate, gap_tolerance_ms=1.0) + + resampled_traces = resampled.get_traces() + n_out_1 = int(resampled.segments[0]._sec_n_out[0]) + n_out_2 = int(resampled.segments[0]._sec_n_out[1]) + assert resampled_traces.shape[0] == n_out_1 + n_out_2 + + # Resample each section independently (no gap in these recordings) + rec1 = NumpyRecording(traces1, sampling_frequency) + rec1.set_times(t1) + with _warnings.catch_warnings(): + _warnings.simplefilter("ignore") + resampled1 = resample(rec1, resample_rate, gap_tolerance_ms=1.0) + ref_traces1 = resampled1.get_traces() + + rec2 = NumpyRecording(traces2, sampling_frequency) + rec2.set_times(t2) + with _warnings.catch_warnings(): + _warnings.simplefilter("ignore") + resampled2 = resample(rec2, resample_rate, gap_tolerance_ms=1.0) + ref_traces2 = resampled2.get_traces() + + # Section 1 should match the independently resampled section 1 + gapped_s1 = resampled_traces[:n_out_1] + assert gapped_s1.shape == ref_traces1.shape, ( + f"Section 1 shape mismatch: {gapped_s1.shape} vs {ref_traces1.shape}" + ) + np.testing.assert_allclose(gapped_s1, ref_traces1, rtol=1e-5, atol=1e-5) + + # Section 2 should match the independently resampled section 2 + gapped_s2 = resampled_traces[n_out_1 : n_out_1 + n_out_2] + assert gapped_s2.shape == ref_traces2.shape, ( + f"Section 2 shape mismatch: {gapped_s2.shape} vs {ref_traces2.shape}" + ) + np.testing.assert_allclose(gapped_s2, ref_traces2, rtol=1e-5, atol=1e-5) + + +def test_resample_gapped_chunked_consistency(): + """Chunked .save() should match non-chunked for gapped recordings.""" + sampling_frequency = 30000 + resample_rate = 700 + rec, _, _ = _make_gapped_recording(sampling_frequency=sampling_frequency, sec1_duration=2.0, sec2_duration=2.0) + + import warnings as _warnings + + with _warnings.catch_warnings(): + _warnings.simplefilter("ignore") + resampled = resample(rec, resample_rate, gap_tolerance_ms=1.0) + + traces_full = resampled.get_traces() + chunk_size = resample_rate # 1 second chunks + saved = resampled.save(format="memory", chunk_size=chunk_size, n_jobs=1, progress_bar=False) + traces_chunked = saved.get_traces() + + assert traces_full.shape == traces_chunked.shape + # Interior samples should match closely (edges may have small differences) + sl = slice(chunk_size, -chunk_size) + rms = np.sqrt(np.mean(traces_full[sl] ** 2)) + if rms > 0: + error = np.sqrt(np.mean((traces_full[sl] - traces_chunked[sl]) ** 2)) + assert error / rms < 0.05, f"Chunked vs full RMS error ratio: {error / rms:.4f}" + + +def test_resample_no_gap_unchanged_behavior(): + """Uniform time_vector without gaps should produce identical results with or without gap_tolerance_ms.""" + sampling_frequency = 30000 + resample_rate = 700 + n_samples = sampling_frequency * 2 + traces = np.random.randn(n_samples, 1).astype(np.float32) + rec = NumpyRecording(traces, sampling_frequency) + + tv = np.arange(n_samples, dtype="float64") / sampling_frequency + 100.0 + rec.set_times(tv) + + import warnings as _warnings + + with _warnings.catch_warnings(): + _warnings.simplefilter("ignore") + # Without gap_tolerance_ms (no gaps, so no error) + resampled_default = resample(rec, resample_rate) + # With gap_tolerance_ms (no gaps to split on, should be identical) + resampled_tolerant = resample(rec, resample_rate, gap_tolerance_ms=1.0) + + np.testing.assert_array_equal(resampled_default.get_times(), resampled_tolerant.get_times()) + np.testing.assert_array_equal(resampled_default.get_traces(), resampled_tolerant.get_traces()) + + +def test_resample_multiple_gaps(): + """Recording with multiple gaps should produce the correct number of sections.""" + sampling_frequency = 30000 + resample_rate = 700 + n_per_sec = int(0.5 * sampling_frequency) # 0.5s per section + n_sections = 4 + n_total = n_per_sec * n_sections + traces = np.random.randn(n_total, 1).astype(np.float32) + rec = NumpyRecording(traces, sampling_frequency) + + # Create time_vector with 3 gaps (between 4 sections) + tv = np.arange(n_total, dtype="float64") / sampling_frequency + for i in range(1, n_sections): + tv[i * n_per_sec :] += (i * 10.0) # gaps of 10s, 20s, 30s cumulative offsets + rec.set_times(tv) + + import warnings as _warnings + + with _warnings.catch_warnings(): + _warnings.simplefilter("ignore") + resampled = resample(rec, resample_rate, gap_tolerance_ms=1.0) + + resampled_times = resampled.get_times() + diffs = np.diff(resampled_times) + normal_dt = 1.0 / resample_rate + + # Should detect 3 gaps + gap_indices = np.where(diffs > normal_dt * 2)[0] + assert len(gap_indices) == 3, f"Expected 3 gaps, found {len(gap_indices)}" + + # Traces and times should match in length + assert resampled.get_traces().shape[0] == len(resampled_times) + + +def test_resample_gap_tolerance_filtering(): + """Gaps smaller than gap_tolerance_ms should be treated as continuous.""" + sampling_frequency = 30000 + resample_rate = 500 # integer ratio for simplicity + + n1 = int(1.0 * sampling_frequency) + n2 = int(1.0 * sampling_frequency) + n3 = int(1.0 * sampling_frequency) + n_total = n1 + n2 + n3 + traces = np.random.randn(n_total, 1).astype(np.float32) + rec = NumpyRecording(traces, sampling_frequency) + + tv = np.arange(n_total, dtype="float64") / sampling_frequency + # Small gap (5ms = 0.005s) after section 1 — detectable but below 50ms tolerance + tv[n1:] += 0.005 + # Large gap (100ms = 0.1s) after section 2 + tv[n1 + n2 :] += 0.1 + rec.set_times(tv) + + # With tolerance of 50ms: only the 100ms gap triggers a section split + resampled = resample(rec, resample_rate, gap_tolerance_ms=50.0) + seg = resampled.segments[0] + assert seg._has_gaps, "Should detect at least one gap" + n_sections = len(seg._sec_n_out) + assert n_sections == 2, f"Expected 2 sections (split at 100ms gap), found {n_sections}" + + # With tolerance of 1.0ms: both gaps trigger section splits + # (5ms > 1ms, 100ms > 1ms) + resampled_strict = resample(rec, resample_rate, gap_tolerance_ms=1.0) + seg_strict = resampled_strict.segments[0] + n_sections_strict = len(seg_strict._sec_n_out) + assert n_sections_strict == 3, f"Expected 3 sections with 1ms tolerance, found {n_sections_strict}" + + if __name__ == "__main__": test_resample_freq_domain() test_resample_by_chunks() @@ -305,3 +551,10 @@ def test_resample_preserves_time_vector_non_integer_ratio(): test_resample_does_not_mutate_parent() test_resample_preserves_time_vector_integer_ratio() test_resample_preserves_time_vector_non_integer_ratio() + test_resample_errors_on_gaps_by_default() + test_resample_preserves_gaps_non_integer_ratio() + test_resample_traces_across_gap() + test_resample_gapped_chunked_consistency() + test_resample_no_gap_unchanged_behavior() + test_resample_multiple_gaps() + test_resample_gap_tolerance_filtering() From bfd9a661653f3043f6e899420c9dc3f87feed2aa Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 9 Apr 2026 00:23:58 +0000 Subject: [PATCH 2/2] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/spikeinterface/preprocessing/resample.py | 14 +++++++++----- .../preprocessing/tests/test_resample.py | 10 +++------- 2 files changed, 12 insertions(+), 12 deletions(-) diff --git a/src/spikeinterface/preprocessing/resample.py b/src/spikeinterface/preprocessing/resample.py index 0c574417fd..8271f01fa2 100644 --- a/src/spikeinterface/preprocessing/resample.py +++ b/src/spikeinterface/preprocessing/resample.py @@ -295,7 +295,9 @@ def _get_traces_gapped(self, start_frame, end_frame, channel_indices): if start_frame == end_frame: # Determine n_channels from parent - n_channels = len(channel_indices) if channel_indices is not None else self._parent_segment.get_traces(0, 1).shape[1] + n_channels = ( + len(channel_indices) if channel_indices is not None else self._parent_segment.get_traces(0, 1).shape[1] + ) return np.empty((0, n_channels), dtype=self._dtype) # Find which sections overlap [start_frame, end_frame) in output space. @@ -342,9 +344,9 @@ def _get_traces_gapped(self, start_frame, end_frame, channel_indices): par_fetch_end = par_start_k + local_par_end + right_margin # Fetch parent traces for this section's sub-chunk - parent_traces = self._parent_segment.get_traces( - par_fetch_start, par_fetch_end, channel_indices - ).astype(np.float32) + parent_traces = self._parent_segment.get_traces(par_fetch_start, par_fetch_end, channel_indices).astype( + np.float32 + ) # Apply reflect padding if margin was truncated at section edge pad_left = self._margin - left_margin @@ -375,7 +377,9 @@ def _get_traces_gapped(self, start_frame, end_frame, channel_indices): pieces.append(resampled) if len(pieces) == 0: - n_channels = len(channel_indices) if channel_indices is not None else self._parent_segment.get_traces(0, 1).shape[1] + n_channels = ( + len(channel_indices) if channel_indices is not None else self._parent_segment.get_traces(0, 1).shape[1] + ) return np.empty((0, n_channels), dtype=self._dtype) result = np.concatenate(pieces, axis=0) if len(pieces) > 1 else pieces[0] return result.astype(self._dtype) diff --git a/src/spikeinterface/preprocessing/tests/test_resample.py b/src/spikeinterface/preprocessing/tests/test_resample.py index bc2ec27bbe..7f79ae24ee 100644 --- a/src/spikeinterface/preprocessing/tests/test_resample.py +++ b/src/spikeinterface/preprocessing/tests/test_resample.py @@ -413,16 +413,12 @@ def test_resample_traces_across_gap(): # Section 1 should match the independently resampled section 1 gapped_s1 = resampled_traces[:n_out_1] - assert gapped_s1.shape == ref_traces1.shape, ( - f"Section 1 shape mismatch: {gapped_s1.shape} vs {ref_traces1.shape}" - ) + assert gapped_s1.shape == ref_traces1.shape, f"Section 1 shape mismatch: {gapped_s1.shape} vs {ref_traces1.shape}" np.testing.assert_allclose(gapped_s1, ref_traces1, rtol=1e-5, atol=1e-5) # Section 2 should match the independently resampled section 2 gapped_s2 = resampled_traces[n_out_1 : n_out_1 + n_out_2] - assert gapped_s2.shape == ref_traces2.shape, ( - f"Section 2 shape mismatch: {gapped_s2.shape} vs {ref_traces2.shape}" - ) + assert gapped_s2.shape == ref_traces2.shape, f"Section 2 shape mismatch: {gapped_s2.shape} vs {ref_traces2.shape}" np.testing.assert_allclose(gapped_s2, ref_traces2, rtol=1e-5, atol=1e-5) @@ -489,7 +485,7 @@ def test_resample_multiple_gaps(): # Create time_vector with 3 gaps (between 4 sections) tv = np.arange(n_total, dtype="float64") / sampling_frequency for i in range(1, n_sections): - tv[i * n_per_sec :] += (i * 10.0) # gaps of 10s, 20s, 30s cumulative offsets + tv[i * n_per_sec :] += i * 10.0 # gaps of 10s, 20s, 30s cumulative offsets rec.set_times(tv) import warnings as _warnings