72 changes: 55 additions & 17 deletions src/spikeinterface/postprocessing/principal_component.py
@@ -633,27 +633,65 @@ def _all_pc_extractor_chunk(segment_index, start_frame, end_frame, worker_ctx):
end = int(spike_times[i1 - 1] + nafter)
traces = recording.get_traces(start_frame=start, end_frame=end, segment_index=segment_index)

for i in range(i0, i1):
st = spike_times[i]
if st - start - nbefore < 0:
continue
if st - start + nafter > traces.shape[0]:
continue
nsamples = nbefore + nafter

wf = traces[st - start - nbefore : st - start + nafter, :]
# Extract all waveforms in the chunk at once
# valid_mask tracks which spikes have valid (in-bounds) waveforms
chunk_spike_times = spike_times[i0:i1]
offsets = chunk_spike_times - start - nbefore
valid_mask = (offsets >= 0) & (offsets + nsamples <= traces.shape[0])
Member
This is probably me misreading, but I'm bad at parsing these kinds of > and < comparisons in general. If we only have to be less than or equal to the shape, couldn't we run into an issue where we are equal to the shape, which is out of bounds?

i.e. for an array of shape (4, 5), shape[0] == 4, but if I try to index at 4 it will be an out-of-bounds error. Again, I don't work on the PC code at all, so maybe I'm completely wrong here.
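For reference, the off-by-one worry being described can be reproduced with a minimal standalone NumPy snippet (toy array, not code from the PR):

```python
import numpy as np

# The worry: for an array of shape (4, 5), shape[0] is 4,
# but 4 itself is NOT a valid index on axis 0 (valid rows are 0..3).
a = np.zeros((4, 5))
assert a.shape[0] == 4

try:
    a[4]                      # one past the end on axis 0
    reached = True
except IndexError:
    reached = False
assert reached is False       # indexing at shape[0] raises IndexError
```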

Member

You mean that the second <= should be just <?

Member

In that case, I agree. It should be

Suggested change:

- valid_mask = (offsets >= 0) & (offsets + nsamples <= traces.shape[0])
+ valid_mask = (offsets >= 0) & (offsets + nsamples < traces.shape[0])

to avoid any indexing error.

E,g. if nsamples = 90, traces shape is 300, and you have a spike at 210, then it would be valid, but L654 will fail because it'll try to access traces[300]...

Member

Actually, we could get rid of the valid_mask completely by extracting traces using a margin of max(nbefore, nafter). @galenlynch do you mind if I give it a try in this PR?

Author (@galenlynch, Apr 11, 2026)

I don't think there's a bug here. This matches the original semantics (if ... > traces.shape[0]: continue), and the original semantics were correct.

The relevant code:

  valid_mask = (offsets >= 0) & (offsets + nsamples <= traces.shape[0])   # L642
  ...
  sample_indices = valid_offsets[:, None] + np.arange(nsamples)[None, :]  # L653
  all_wfs = traces[sample_indices]                                        # L654

np.arange(nsamples) produces [0, 1, ..., nsamples-1]. So for a spike with offset o, the fancy indices accessed are [o, o+1, ..., o + nsamples - 1]. The maximum index accessed is o + nsamples - 1.

For that to be in bounds, we need:

o + nsamples - 1 ≤ traces.shape[0] - 1, i.e. o + nsamples ≤ traces.shape[0]

@alejoe91 your example is also not quite right:
With offset = 210, nsamples = 90:

  • offset + nsamples = 300 ≤ 300 passes the mask
  • sample_indices = 210 + [0..89] = [210, 211, ..., 299]
  • traces[[210, ..., 299]]: max index is 299, not 300.

Changing the bound to < would drop valid spikes at the end of the chunk.
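A quick standalone check of this arithmetic (a NumPy sketch using the sizes from the example above, not code lifted from the PR):

```python
import numpy as np

# Sizes from the example: 300-sample trace, 90-sample window
nsamples = 90
traces = np.arange(300 * 2).reshape(300, 2)   # (300 samples, 2 channels)
offsets = np.array([0, 105, 210])             # 210 is the disputed edge case

# The <= bound accepts offset 210, since 210 + 90 == 300 == traces.shape[0]
valid_mask = (offsets >= 0) & (offsets + nsamples <= traces.shape[0])
assert valid_mask.all()

# Fancy indices for each spike run from offset to offset + nsamples - 1,
# so the largest index touched is 210 + 89 = 299, still in bounds
sample_indices = offsets[:, None] + np.arange(nsamples)[None, :]
all_wfs = traces[sample_indices]              # no IndexError raised
assert all_wfs.shape == (3, nsamples, 2)
assert sample_indices.max() == 299

# A strict < bound would wrongly drop the last, perfectly valid spike
strict_mask = (offsets >= 0) & (offsets + nsamples < traces.shape[0])
assert strict_mask.tolist() == [True, True, False]
```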

Using get_chunk_with_margin is not the right fix, because (1) I don't think there's a bug, (2) it would overfetch data, and (3) it would require other changes to the indexing logic.


You're right, though, that valid_mask is leftover, unreachable logic that replaced the original loop's bounds checks.

Look at how start and end are computed (lines 632–634):

start = int(spike_times[i0] - nbefore)
end = int(spike_times[i1 - 1] + nafter)
traces = recording.get_traces(start_frame=start, end_frame=end, ...)

The traces array is tight-bracketed around exactly the spikes in [i0, i1). Let's check the extremes:

  • First spike: offset = spike_times[i0] - start - nbefore = 0 → accesses [0, nsamples-1] (correct)
  • Last spike: offset = spike_times[i1-1] - spike_times[i0], so offset + nsamples = spike_times[i1-1] - spike_times[i0] + nbefore + nafter = end - start = traces.shape[0] → accesses up to traces.shape[0]-1 (correct)

So valid_mask is always all-True on the happy path. The only time anything gets dropped is at absolute segment boundaries, and those are already handled by the i0 += 1 / i1 -= 1 loop at lines 620–627 before start/end are even computed. The valid_mask is dead defensive code inherited from the original loop's continue checks.
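The tight-bracketing argument can be verified with a toy standalone example (made-up spike times and window sizes, not the PR's code):

```python
import numpy as np

# Toy chunk: three spikes, window of nbefore + nafter samples around each
nbefore, nafter = 30, 60
nsamples = nbefore + nafter
spike_times = np.array([100, 250, 400])
i0, i1 = 0, len(spike_times)

# start/end bracket exactly the spikes in [i0, i1), as in the PR
start = int(spike_times[i0] - nbefore)        # 70
end = int(spike_times[i1 - 1] + nafter)       # 460
n_trace = end - start                         # stands in for traces.shape[0]

offsets = spike_times[i0:i1] - start - nbefore
# First spike's offset is 0; last spike's offset + nsamples == n_trace
assert offsets[0] == 0
assert offsets[-1] + nsamples == n_trace

# So inside a chunk the mask is always all-True
valid_mask = (offsets >= 0) & (offsets + nsamples <= n_trace)
assert valid_mask.all()
```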

We could delete it entirely:

# Before
chunk_spike_times = spike_times[i0:i1]
offsets = chunk_spike_times - start - nbefore
valid_mask = (offsets >= 0) & (offsets + nsamples <= traces.shape[0])
if not np.any(valid_mask):
    return
valid_offsets = offsets[valid_mask]
valid_indices = np.arange(i0, i1)[valid_mask]
n_valid = len(valid_offsets)

# After
offsets = spike_times[i0:i1] - start - nbefore
valid_indices = np.arange(i0, i1)
n_valid = i1 - i0
sample_indices = offsets[:, None] + np.arange(nsamples)[None, :]
all_wfs = traces[sample_indices]

What do you think?


unit_index = spike_labels[i]
chan_inds = unit_channels[unit_index]
if not np.any(valid_mask):
return

valid_offsets = offsets[valid_mask]
valid_indices = np.arange(i0, i1)[valid_mask]
n_valid = len(valid_offsets)

# Build waveform array: (n_valid, nsamples, n_channels)
# Use fancy indexing to extract all snippets at once
sample_indices = valid_offsets[:, None] + np.arange(nsamples)[None, :] # (n_valid, nsamples)
all_wfs = traces[sample_indices] # (n_valid, nsamples, n_channels)

# Vectorized PCA: batch by channel across all spikes in the chunk.
# For each unique channel, find all spikes that use it (via their unit's
# sparsity), extract waveforms, and call transform once.
valid_labels = spike_labels[valid_indices]

# Build a set of all channels used by spikes in this chunk
unique_unit_indices = np.unique(valid_labels)
chan_info: dict[int, list[tuple[np.ndarray, int]]] = {}
for unit_index in unique_unit_indices:
chan_inds = unit_channels[unit_index]
unit_mask = valid_labels == unit_index
unit_local_idxs = np.nonzero(unit_mask)[0]
for c, chan_ind in enumerate(chan_inds):
w = wf[:, chan_ind]
if w.size > 0:
w = w[None, :]
try:
all_pcs[i, :, c] = pca_model[chan_ind].transform(w)
except:
# this could happen if len(wfs) is less then n_comp for a channel
pass
if chan_ind not in chan_info:
chan_info[chan_ind] = []
chan_info[chan_ind].append((unit_local_idxs, c))

for chan_ind, unit_groups in chan_info.items():
# Concatenate all spike indices for this channel across units
all_local_idxs = np.concatenate([g[0] for g in unit_groups])
global_idxs = valid_indices[all_local_idxs]

# Batch waveforms for this channel: (n_spikes, nsamples)
wfs_batch = all_wfs[all_local_idxs, :, chan_ind]

if wfs_batch.size == 0:
continue

try:
pcs_batch = pca_model[chan_ind].transform(wfs_batch)
# Write results back — each unit group has a fixed channel position
offset = 0
for unit_local_idxs, c_pos in unit_groups:
n = len(unit_local_idxs)
all_pcs[global_idxs[offset : offset + n], :, c_pos] = pcs_batch[offset : offset + n]
offset += n
except Exception:
# this could happen if len(wfs) is less than n_comp for a channel
pass


def _init_work_all_pc_extractor(recording, sorting, all_pcs_args, nbefore, nafter, unit_channels, pca_model):