Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 28 additions & 8 deletions src/spikeinterface/sorters/internal/lupin.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,22 +42,25 @@ class LupinSorter(ComponentsBasedSorter):
"clustering_ms_after": 1.3,
"whitening_radius_um": 100.0,
"detection_radius_um": 50.0,
"features_radius_um": 75.0,
"features_radius_um": 120.0,
"split_radius_um": 60.0,
"template_radius_um": 100.0,
"merge_similarity_lag_ms": 0.5,
"freq_min": 150.0,
"freq_max": 7000.0,
"cache_preprocessing_mode": "auto",
"peak_sign": "neg",
"detect_threshold": 5,
"detect_threshold": 5.0,
"n_peaks_per_channel": 5000,
"n_svd_components_per_channel": 5,
"n_pca_features": 4,
"clustering_recursive_depth": 3,
"ms_before": 1.0,
"ms_after": 2.5,
"template_sparsify_threshold": 1.5,
"template_sparsify_threshold": 1.0,
"template_min_snr_ptp": 4.0,
"template_max_jitter_ms": 0.2,
"template_matching_engine": "circus-omp",
"min_firing_rate": 0.1,
"gather_mode": "memory",
"job_kwargs": {},
Expand All @@ -74,6 +77,11 @@ class LupinSorter(ComponentsBasedSorter):
"clustering_ms_before": "Milliseconds before the spike peak for clustering",
"clustering_ms_after": "Milliseconds after the spike peak for clustering",
"radius_um": "Radius for sparsity",
"whitening_radius_um": "Radius for whitening",
"detection_radius_um": "Radius for peak detection",
"features_radius_um": "Radius for sparsity in SVD features",
"split_radius_um": "Radius for the local split clustering",
"template_radius_um": "Radius for the sparsity of template before template matching",
"freq_min": "Low frequency",
"freq_max": "High frequency",
"peak_sign": "Sign of peaks neg/pos/both",
Expand All @@ -99,7 +107,7 @@ class LupinSorter(ComponentsBasedSorter):

@classmethod
def get_sorter_version(cls):
return "2025.12"
return "2026.01"

@classmethod
def _run_from_folder(cls, sorter_output_folder, params, verbose):
Expand Down Expand Up @@ -201,6 +209,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
dtype="float32",
mode="local",
radius_um=params["whitening_radius_um"],
seed=seed,
)

if params["apply_motion_correction"]:
Expand All @@ -219,19 +228,23 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

# Cache in mem or folder
cache_folder = sorter_output_folder / "cache_preprocessing"
recording_for_analyzer = recording
recording, cache_info = cache_preprocessing(
recording,
mode=params["cache_preprocessing_mode"],
folder=cache_folder,
job_kwargs=job_kwargs,
)

noise_levels = get_noise_levels(recording, return_in_uV=False)
else:
recording = recording_raw.astype("float32")
noise_levels = get_noise_levels(recording, return_in_uV=False)
recording_for_analyzer = recording
cache_info = None

noise_levels = get_noise_levels(
recording, return_in_uV=False, random_slices_kwargs=dict(seed=seed), **job_kwargs
)

# detection
ms_before = params["ms_before"]
ms_after = params["ms_after"]
Expand Down Expand Up @@ -265,20 +278,26 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
if verbose:
print(f"select_peaks(): {len(peaks)} peaks kept for clustering")

num_shifts_merging = int(sampling_frequency * params["merge_similarity_lag_ms"] / 1000.0)

# Clustering
clustering_kwargs = deepcopy(clustering_methods["iterative-isosplit"]._default_params)
clustering_kwargs["peaks_svd"]["ms_before"] = params["clustering_ms_before"]
clustering_kwargs["peaks_svd"]["ms_after"] = params["clustering_ms_after"]
clustering_kwargs["peaks_svd"]["radius_um"] = params["features_radius_um"]
clustering_kwargs["peaks_svd"]["n_components"] = params["n_svd_components_per_channel"]
clustering_kwargs["split"]["split_radius_um"] = params["split_radius_um"]
clustering_kwargs["split"]["recursive_depth"] = params["clustering_recursive_depth"]
clustering_kwargs["split"]["method_kwargs"]["n_pca_features"] = params["n_pca_features"]
clustering_kwargs["clean_templates"]["sparsify_threshold"] = params["template_sparsify_threshold"]
clustering_kwargs["clean_templates"]["min_snr"] = params["template_min_snr_ptp"]
clustering_kwargs["clean_templates"]["max_jitter_ms"] = params["template_max_jitter_ms"]
clustering_kwargs["merge_from_templates"]["use_lags"] = True
clustering_kwargs["merge_from_templates"]["num_shifts"] = num_shifts_merging
clustering_kwargs["noise_levels"] = noise_levels
clustering_kwargs["clean_low_firing"]["min_firing_rate"] = params["min_firing_rate"]
clustering_kwargs["clean_low_firing"]["subsampling_factor"] = all_peaks.size / peaks.size
clustering_kwargs["seed"] = seed

if params["debug"]:
clustering_kwargs["debug_folder"] = sorter_output_folder
Expand Down Expand Up @@ -353,7 +372,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
spikes = find_spikes_from_templates(
recording,
templates,
method="wobble",
method=params["template_matching_engine"],
method_kwargs={},
pipeline_kwargs=pipeline_kwargs,
job_kwargs=job_kwargs,
Expand All @@ -377,7 +396,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
templates,
amplitude_scalings=spikes["amplitude"],
noise_levels=noise_levels,
similarity_kwargs={"method": "l1", "support": "union", "max_lag_ms": 0.1},
similarity_kwargs={"method": "l1", "support": "union", "max_lag_ms": params["merge_similarity_lag_ms"]},
sparsity_overlap=0.5,
censor_ms=3.0,
max_distance_um=50,
Expand All @@ -396,6 +415,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
np.save(sorter_output_folder / "spikes.npy", spikes)
templates.to_zarr(sorter_output_folder / "templates.zarr")
if analyzer_final is not None:
analyzer_final._recording = recording_for_analyzer
analyzer_final.save_as(format="binary_folder", folder=sorter_output_folder / "analyzer")

sorting = sorting.save(folder=sorter_output_folder / "sorting")
Expand Down
59 changes: 39 additions & 20 deletions src/spikeinterface/sorters/internal/tridesclous2.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,10 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
"clustering_ms_before": 0.5,
"clustering_ms_after": 1.5,
"detection_radius_um": 150.0,
"features_radius_um": 75.0,
"features_radius_um": 120.0,
"split_radius_um": 60.0,
"template_radius_um": 100.0,
"merge_similarity_lag_ms": 0.5,
"freq_min": 150.0,
"freq_max": 6000.0,
"cache_preprocessing_mode": "auto",
Expand All @@ -48,8 +50,8 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
"n_svd_components_per_channel": 5,
"n_pca_features": 6,
"clustering_recursive_depth": 3,
"ms_before": 2.0,
"ms_after": 3.0,
"ms_before": 1.0,
"ms_after": 2.5,
"template_sparsify_threshold": 1.5,
"template_min_snr_ptp": 3.5,
"template_max_jitter_ms": 0.2,
Expand All @@ -69,6 +71,10 @@ class Tridesclous2Sorter(ComponentsBasedSorter):
"clustering_ms_before": "Milliseconds before the spike peak for clustering",
"clustering_ms_after": "Milliseconds after the spike peak for clustering",
"radius_um": "Radius for sparsity",
"detection_radius_um": "Radius for peak detection",
"features_radius_um": "Radius for sparsity in SVD features",
"split_radius_um": "Radius for the local split clustering",
"template_radius_um": "Radius for the sparsity of template before template matching",
"freq_min": "Low frequency for bandpass filter",
"freq_max": "High frequency for bandpass filter",
"peak_sign": "Sign of peaks neg/pos/both",
Expand All @@ -94,7 +100,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter):

@classmethod
def get_sorter_version(cls):
return "2025.12"
return "2026.01"

@classmethod
def _run_from_folder(cls, sorter_output_folder, params, verbose):
Expand Down Expand Up @@ -162,6 +168,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
recording = apply_preprocessing_pipeline(recording_raw, params["preprocessing_dict"])
recording = recording.astype("float32")

recording = whiten(recording, dtype="float32", mode="local", radius_um=100.0)

if params["apply_motion_correction"]:
interpolate_motion_kwargs = dict(
border_mode="force_extrapolate",
Expand All @@ -176,35 +184,41 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
**interpolate_motion_kwargs,
)

recording = zscore(recording, dtype="float32")
# whitening is really bad when dirft correction is applied and this changd nothing when no dirft
# recording = whiten(recording, dtype="float32", mode="local", radius_um=100.0)

# Cache in mem or folder
cache_folder = sorter_output_folder / "cache_preprocessing"
recording_for_analyzer = recording
recording, cache_info = cache_preprocessing(
recording,
mode=params["cache_preprocessing_mode"],
folder=cache_folder,
job_kwargs=job_kwargs,
)

noise_levels = np.ones(num_chans, dtype="float32")
else:
recording = recording_raw.astype("float32")
noise_levels = get_noise_levels(recording, return_in_uV=False)
recording_for_analyzer = recording
cache_info = None

recording_for_clustering = recording
noise_levels = get_noise_levels(
recording_for_clustering, return_in_uV=False, random_slices_kwargs=dict(seed=seed), **job_kwargs
)

# detection
detection_params = dict(
noise_levels=noise_levels,
peak_sign=params["peak_sign"],
detect_threshold=params["detect_threshold"],
exclude_sweep_ms=1.5,
radius_um=params["detection_radius_um"],
)

all_peaks = detect_peaks(
recording, method="locally_exclusive", method_kwargs=detection_params, job_kwargs=job_kwargs
# recording,
recording_for_clustering,
method="locally_exclusive",
method_kwargs=detection_params,
job_kwargs=job_kwargs,
)

if verbose:
Expand All @@ -217,33 +231,32 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
if verbose:
print(f"select_peaks(): {len(peaks)} peaks kept for clustering")

# if clustering_kwargs["clustering"]["clusterer"] == "isosplit6":
# have_sisosplit6 = importlib.util.find_spec("isosplit6") is not None
# if not have_sisosplit6:
# raise ValueError(
# "You want to run tridesclous2 with the isosplit6 (the C++) implementation, but this is not installed, please `pip install isosplit6`"
# )

# Clustering
num_shifts_merging = int(sampling_frequency * params["merge_similarity_lag_ms"] / 1000.0)

clustering_kwargs = deepcopy(clustering_methods["iterative-isosplit"]._default_params)
clustering_kwargs["peaks_svd"]["ms_before"] = params["clustering_ms_before"]
clustering_kwargs["peaks_svd"]["ms_after"] = params["clustering_ms_after"]
clustering_kwargs["peaks_svd"]["radius_um"] = params["features_radius_um"]
clustering_kwargs["peaks_svd"]["n_components"] = params["n_svd_components_per_channel"]
clustering_kwargs["split"]["split_radius_um"] = params["split_radius_um"]
clustering_kwargs["split"]["recursive_depth"] = params["clustering_recursive_depth"]
clustering_kwargs["split"]["method_kwargs"]["n_pca_features"] = params["n_pca_features"]
clustering_kwargs["clean_templates"]["sparsify_threshold"] = params["template_sparsify_threshold"]
clustering_kwargs["clean_templates"]["min_snr"] = params["template_min_snr_ptp"]
clustering_kwargs["clean_templates"]["max_jitter_ms"] = params["template_max_jitter_ms"]
clustering_kwargs["merge_from_templates"]["num_shifts"] = num_shifts_merging
clustering_kwargs["noise_levels"] = noise_levels
clustering_kwargs["clean_low_firing"]["min_firing_rate"] = params["min_firing_rate"]
clustering_kwargs["clean_low_firing"]["subsampling_factor"] = all_peaks.size / peaks.size
clustering_kwargs["seed"] = seed

if params["debug"]:
clustering_kwargs["debug_folder"] = sorter_output_folder

unit_ids, clustering_label, more_outs = find_clusters_from_peaks(
recording,
# recording,
recording_for_clustering,
peaks,
method="iterative-isosplit",
method_kwargs=clustering_kwargs,
Expand All @@ -264,7 +277,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
if verbose:
print(f"find_clusters_from_peaks(): {unit_ids.size} cluster found")

# here the idea was to be able to use other preprocessing for peeler
# but at teh moment it is the same for clustering and peeling
recording_for_peeler = recording
noise_levels = get_noise_levels(
recording_for_peeler, return_in_uV=False, random_slices_kwargs=dict(seed=seed), **job_kwargs
)

# preestimate the sparsity unsing peaks channel
spike_vector = sorting_pre_peeler.to_spike_vector(concatenated=True)
Expand Down Expand Up @@ -342,7 +360,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
templates,
amplitude_scalings=spikes["amplitude"],
noise_levels=noise_levels,
similarity_kwargs={"method": "l1", "support": "union", "max_lag_ms": 0.1},
similarity_kwargs={"method": "l1", "support": "union", "max_lag_ms": params["merge_similarity_lag_ms"]},
sparsity_overlap=0.5,
censor_ms=3.0,
max_distance_um=50,
Expand All @@ -362,6 +380,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
np.save(sorter_output_folder / "spikes.npy", spikes)
templates.to_zarr(sorter_output_folder / "templates.zarr")
if analyzer_final is not None:
analyzer_final._recording = recording_for_analyzer
analyzer_final.save_as(format="binary_folder", folder=sorter_output_folder / "analyzer")

sorting = sorting.save(folder=sorter_output_folder / "sorting")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ class IterativeISOSPLITClustering:
"similarity_metric": "l1",
"num_shifts": 3,
"similarity_thresh": 0.8,
"use_lags": True,
},
"merge_from_features": None,
# "merge_from_features": {"merge_radius_um": 60.0},
Expand Down Expand Up @@ -106,13 +107,16 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):

ms_before = params["peaks_svd"]["ms_before"]
ms_after = params["peaks_svd"]["ms_after"]
nbefore = int(ms_before * recording.sampling_frequency / 1000.0)
nafter = int(ms_after * recording.sampling_frequency / 1000.0)

# radius_um = params["waveforms"]["radius_um"]
verbose = params["verbose"]

debug_folder = params["debug_folder"]

params_peak_svd = params["peaks_svd"].copy()

params_peak_svd["seed"] = params["seed"]
motion = params_peak_svd["motion"]
motion_aware = motion is not None

Expand Down Expand Up @@ -285,13 +289,17 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()):
post_merge_label1 = post_split_label.copy()

if params["merge_from_templates"] is not None:
params_merge_from_templates = params["merge_from_templates"].copy()
num_shifts = params_merge_from_templates["num_shifts"]
num_shifts = min((num_shifts, nbefore, nafter))
params_merge_from_templates["num_shifts"] = num_shifts
post_merge_label2, templates_array, template_sparse_mask, unit_ids = merge_peak_labels_from_templates(
peaks,
post_merge_label1,
unit_ids,
templates_array,
template_sparse_mask,
**params["merge_from_templates"],
**params_merge_from_templates,
)
else:
post_merge_label2 = post_merge_label1.copy()
Expand Down