diff --git a/src/spikeinterface/sorters/internal/lupin.py b/src/spikeinterface/sorters/internal/lupin.py
index 13578b0342..7bf9c42beb 100644
--- a/src/spikeinterface/sorters/internal/lupin.py
+++ b/src/spikeinterface/sorters/internal/lupin.py
@@ -42,22 +42,25 @@ class LupinSorter(ComponentsBasedSorter):
         "clustering_ms_after": 1.3,
         "whitening_radius_um": 100.0,
         "detection_radius_um": 50.0,
-        "features_radius_um": 75.0,
+        "features_radius_um": 120.0,
+        "split_radius_um": 60.0,
         "template_radius_um": 100.0,
+        "merge_similarity_lag_ms": 0.5,
         "freq_min": 150.0,
         "freq_max": 7000.0,
         "cache_preprocessing_mode": "auto",
         "peak_sign": "neg",
-        "detect_threshold": 5,
+        "detect_threshold": 5.0,
         "n_peaks_per_channel": 5000,
         "n_svd_components_per_channel": 5,
         "n_pca_features": 4,
         "clustering_recursive_depth": 3,
         "ms_before": 1.0,
         "ms_after": 2.5,
-        "template_sparsify_threshold": 1.5,
+        "template_sparsify_threshold": 1.0,
         "template_min_snr_ptp": 4.0,
         "template_max_jitter_ms": 0.2,
+        "template_matching_engine": "circus-omp",
         "min_firing_rate": 0.1,
         "gather_mode": "memory",
         "job_kwargs": {},
@@ -74,6 +77,11 @@ class LupinSorter(ComponentsBasedSorter):
         "clustering_ms_before": "Milliseconds before the spike peak for clustering",
         "clustering_ms_after": "Milliseconds after the spike peak for clustering",
         "radius_um": "Radius for sparsity",
+        "whitening_radius_um": "Radius for whitening",
+        "detection_radius_um": "Radius for peak detection",
+        "features_radius_um": "Radius for sparsity in SVD features",
+        "split_radius_um": "Radius for the local split clustering",
+        "template_radius_um": "Radius for the sparsity of template before template matching",
         "freq_min": "Low frequency",
         "freq_max": "High frequency",
         "peak_sign": "Sign of peaks neg/pos/both",
@@ -99,7 +107,7 @@ class LupinSorter(ComponentsBasedSorter):

     @classmethod
     def get_sorter_version(cls):
-        return "2025.12"
+        return "2026.01"

     @classmethod
     def _run_from_folder(cls, sorter_output_folder, params, verbose):
@@ -201,6 +209,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
                 dtype="float32",
                 mode="local",
                 radius_um=params["whitening_radius_um"],
+                seed=seed,
             )

             if params["apply_motion_correction"]:
@@ -219,6 +228,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):

             # Cache in mem or folder
             cache_folder = sorter_output_folder / "cache_preprocessing"
+            recording_for_analyzer = recording
             recording, cache_info = cache_preprocessing(
                 recording,
                 mode=params["cache_preprocessing_mode"],
@@ -226,12 +236,15 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
                 job_kwargs=job_kwargs,
             )

-            noise_levels = get_noise_levels(recording, return_in_uV=False)
         else:
             recording = recording_raw.astype("float32")
-            noise_levels = get_noise_levels(recording, return_in_uV=False)
+            recording_for_analyzer = recording
             cache_info = None

+        noise_levels = get_noise_levels(
+            recording, return_in_uV=False, random_slices_kwargs=dict(seed=seed), **job_kwargs
+        )
+
         # detection
         ms_before = params["ms_before"]
         ms_after = params["ms_after"]
@@ -265,20 +278,26 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose):
         if verbose:
             print(f"select_peaks(): {len(peaks)} peaks kept for clustering")

+        num_shifts_merging = int(sampling_frequency * params["merge_similarity_lag_ms"] / 1000.0)
+
         # Clustering
         clustering_kwargs = deepcopy(clustering_methods["iterative-isosplit"]._default_params)
         clustering_kwargs["peaks_svd"]["ms_before"] = params["clustering_ms_before"]
         clustering_kwargs["peaks_svd"]["ms_after"] = params["clustering_ms_after"]
clustering_kwargs["peaks_svd"]["radius_um"] = params["features_radius_um"] clustering_kwargs["peaks_svd"]["n_components"] = params["n_svd_components_per_channel"] + clustering_kwargs["split"]["split_radius_um"] = params["split_radius_um"] clustering_kwargs["split"]["recursive_depth"] = params["clustering_recursive_depth"] clustering_kwargs["split"]["method_kwargs"]["n_pca_features"] = params["n_pca_features"] clustering_kwargs["clean_templates"]["sparsify_threshold"] = params["template_sparsify_threshold"] clustering_kwargs["clean_templates"]["min_snr"] = params["template_min_snr_ptp"] clustering_kwargs["clean_templates"]["max_jitter_ms"] = params["template_max_jitter_ms"] + clustering_kwargs["merge_from_templates"]["use_lags"] = True + clustering_kwargs["merge_from_templates"]["num_shifts"] = num_shifts_merging clustering_kwargs["noise_levels"] = noise_levels clustering_kwargs["clean_low_firing"]["min_firing_rate"] = params["min_firing_rate"] clustering_kwargs["clean_low_firing"]["subsampling_factor"] = all_peaks.size / peaks.size + clustering_kwargs["seed"] = seed if params["debug"]: clustering_kwargs["debug_folder"] = sorter_output_folder @@ -353,7 +372,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): spikes = find_spikes_from_templates( recording, templates, - method="wobble", + method=params["template_matching_engine"], method_kwargs={}, pipeline_kwargs=pipeline_kwargs, job_kwargs=job_kwargs, @@ -377,7 +396,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): templates, amplitude_scalings=spikes["amplitude"], noise_levels=noise_levels, - similarity_kwargs={"method": "l1", "support": "union", "max_lag_ms": 0.1}, + similarity_kwargs={"method": "l1", "support": "union", "max_lag_ms": params["merge_similarity_lag_ms"]}, sparsity_overlap=0.5, censor_ms=3.0, max_distance_um=50, @@ -396,6 +415,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): np.save(sorter_output_folder / "spikes.npy", spikes) templates.to_zarr(sorter_output_folder / "templates.zarr") if analyzer_final is not None: + analyzer_final._recording = recording_for_analyzer analyzer_final.save_as(format="binary_folder", folder=sorter_output_folder / "analyzer") sorting = sorting.save(folder=sorter_output_folder / "sorting") diff --git a/src/spikeinterface/sorters/internal/tridesclous2.py b/src/spikeinterface/sorters/internal/tridesclous2.py index 373305c336..13175e6100 100644 --- a/src/spikeinterface/sorters/internal/tridesclous2.py +++ b/src/spikeinterface/sorters/internal/tridesclous2.py @@ -37,8 +37,10 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "clustering_ms_before": 0.5, "clustering_ms_after": 1.5, "detection_radius_um": 150.0, - "features_radius_um": 75.0, + "features_radius_um": 120.0, + "split_radius_um": 60.0, "template_radius_um": 100.0, + "merge_similarity_lag_ms": 0.5, "freq_min": 150.0, "freq_max": 6000.0, "cache_preprocessing_mode": "auto", @@ -48,8 +50,8 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "n_svd_components_per_channel": 5, "n_pca_features": 6, "clustering_recursive_depth": 3, - "ms_before": 2.0, - "ms_after": 3.0, + "ms_before": 1.0, + "ms_after": 2.5, "template_sparsify_threshold": 1.5, "template_min_snr_ptp": 3.5, "template_max_jitter_ms": 0.2, @@ -69,6 +71,10 @@ class Tridesclous2Sorter(ComponentsBasedSorter): "clustering_ms_before": "Milliseconds before the spike peak for clustering", "clustering_ms_after": "Milliseconds after the spike peak for clustering", "radius_um": "Radius for sparsity", + "detection_radius_um": "Radius 
for peak detection", + "features_radius_um": "Radius for sparsity in SVD features", + "split_radius_um": "Radius for the local split clustering", + "template_radius_um": "Radius for the sparsity of template before template matching", "freq_min": "Low frequency for bandpass filter", "freq_max": "High frequency for bandpass filter", "peak_sign": "Sign of peaks neg/pos/both", @@ -94,7 +100,7 @@ class Tridesclous2Sorter(ComponentsBasedSorter): @classmethod def get_sorter_version(cls): - return "2025.12" + return "2026.01" @classmethod def _run_from_folder(cls, sorter_output_folder, params, verbose): @@ -162,6 +168,8 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): recording = apply_preprocessing_pipeline(recording_raw, params["preprocessing_dict"]) recording = recording.astype("float32") + recording = whiten(recording, dtype="float32", mode="local", radius_um=100.0) + if params["apply_motion_correction"]: interpolate_motion_kwargs = dict( border_mode="force_extrapolate", @@ -176,12 +184,9 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): **interpolate_motion_kwargs, ) - recording = zscore(recording, dtype="float32") - # whitening is really bad when dirft correction is applied and this changd nothing when no dirft - # recording = whiten(recording, dtype="float32", mode="local", radius_um=100.0) - # Cache in mem or folder cache_folder = sorter_output_folder / "cache_preprocessing" + recording_for_analyzer = recording recording, cache_info = cache_preprocessing( recording, mode=params["cache_preprocessing_mode"], @@ -189,14 +194,19 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): job_kwargs=job_kwargs, ) - noise_levels = np.ones(num_chans, dtype="float32") else: recording = recording_raw.astype("float32") - noise_levels = get_noise_levels(recording, return_in_uV=False) + recording_for_analyzer = recording cache_info = None + recording_for_clustering = recording + noise_levels = get_noise_levels( + recording_for_clustering, return_in_uV=False, random_slices_kwargs=dict(seed=seed), **job_kwargs + ) + # detection detection_params = dict( + noise_levels=noise_levels, peak_sign=params["peak_sign"], detect_threshold=params["detect_threshold"], exclude_sweep_ms=1.5, @@ -204,7 +214,11 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): ) all_peaks = detect_peaks( - recording, method="locally_exclusive", method_kwargs=detection_params, job_kwargs=job_kwargs + # recording, + recording_for_clustering, + method="locally_exclusive", + method_kwargs=detection_params, + job_kwargs=job_kwargs, ) if verbose: @@ -217,33 +231,32 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if verbose: print(f"select_peaks(): {len(peaks)} peaks kept for clustering") - # if clustering_kwargs["clustering"]["clusterer"] == "isosplit6": - # have_sisosplit6 = importlib.util.find_spec("isosplit6") is not None - # if not have_sisosplit6: - # raise ValueError( - # "You want to run tridesclous2 with the isosplit6 (the C++) implementation, but this is not installed, please `pip install isosplit6`" - # ) - # Clustering + num_shifts_merging = int(sampling_frequency * params["merge_similarity_lag_ms"] / 1000.0) + clustering_kwargs = deepcopy(clustering_methods["iterative-isosplit"]._default_params) clustering_kwargs["peaks_svd"]["ms_before"] = params["clustering_ms_before"] clustering_kwargs["peaks_svd"]["ms_after"] = params["clustering_ms_after"] clustering_kwargs["peaks_svd"]["radius_um"] = params["features_radius_um"] 
clustering_kwargs["peaks_svd"]["n_components"] = params["n_svd_components_per_channel"] + clustering_kwargs["split"]["split_radius_um"] = params["split_radius_um"] clustering_kwargs["split"]["recursive_depth"] = params["clustering_recursive_depth"] clustering_kwargs["split"]["method_kwargs"]["n_pca_features"] = params["n_pca_features"] clustering_kwargs["clean_templates"]["sparsify_threshold"] = params["template_sparsify_threshold"] clustering_kwargs["clean_templates"]["min_snr"] = params["template_min_snr_ptp"] clustering_kwargs["clean_templates"]["max_jitter_ms"] = params["template_max_jitter_ms"] + clustering_kwargs["merge_from_templates"]["num_shifts"] = num_shifts_merging clustering_kwargs["noise_levels"] = noise_levels clustering_kwargs["clean_low_firing"]["min_firing_rate"] = params["min_firing_rate"] clustering_kwargs["clean_low_firing"]["subsampling_factor"] = all_peaks.size / peaks.size + clustering_kwargs["seed"] = seed if params["debug"]: clustering_kwargs["debug_folder"] = sorter_output_folder unit_ids, clustering_label, more_outs = find_clusters_from_peaks( - recording, + # recording, + recording_for_clustering, peaks, method="iterative-isosplit", method_kwargs=clustering_kwargs, @@ -264,7 +277,12 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): if verbose: print(f"find_clusters_from_peaks(): {unit_ids.size} cluster found") + # here the idea was to be able to use other preprocessing for peeler + # but at teh moment it is the same for clustering and peeling recording_for_peeler = recording + noise_levels = get_noise_levels( + recording_for_peeler, return_in_uV=False, random_slices_kwargs=dict(seed=seed), **job_kwargs + ) # preestimate the sparsity unsing peaks channel spike_vector = sorting_pre_peeler.to_spike_vector(concatenated=True) @@ -342,7 +360,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): templates, amplitude_scalings=spikes["amplitude"], noise_levels=noise_levels, - similarity_kwargs={"method": "l1", "support": "union", "max_lag_ms": 0.1}, + similarity_kwargs={"method": "l1", "support": "union", "max_lag_ms": params["merge_similarity_lag_ms"]}, sparsity_overlap=0.5, censor_ms=3.0, max_distance_um=50, @@ -362,6 +380,7 @@ def _run_from_folder(cls, sorter_output_folder, params, verbose): np.save(sorter_output_folder / "spikes.npy", spikes) templates.to_zarr(sorter_output_folder / "templates.zarr") if analyzer_final is not None: + analyzer_final._recording = recording_for_analyzer analyzer_final.save_as(format="binary_folder", folder=sorter_output_folder / "analyzer") sorting = sorting.save(folder=sorter_output_folder / "sorting") diff --git a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py index b1d54df80c..7c06502088 100644 --- a/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py +++ b/src/spikeinterface/sortingcomponents/clustering/iterative_isosplit.py @@ -72,6 +72,7 @@ class IterativeISOSPLITClustering: "similarity_metric": "l1", "num_shifts": 3, "similarity_thresh": 0.8, + "use_lags": True, }, "merge_from_features": None, # "merge_from_features": {"merge_radius_um": 60.0}, @@ -106,13 +107,16 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): ms_before = params["peaks_svd"]["ms_before"] ms_after = params["peaks_svd"]["ms_after"] + nbefore = int(ms_before * recording.sampling_frequency / 1000.0) + nafter = int(ms_after * recording.sampling_frequency / 1000.0) + # radius_um = 
params["waveforms"]["radius_um"] verbose = params["verbose"] debug_folder = params["debug_folder"] params_peak_svd = params["peaks_svd"].copy() - + params_peak_svd["seed"] = params["seed"] motion = params_peak_svd["motion"] motion_aware = motion is not None @@ -285,13 +289,17 @@ def main_function(cls, recording, peaks, params, job_kwargs=dict()): post_merge_label1 = post_split_label.copy() if params["merge_from_templates"] is not None: + params_merge_from_templates = params["merge_from_templates"].copy() + num_shifts = params_merge_from_templates["num_shifts"] + num_shifts = min((num_shifts, nbefore, nafter)) + params_merge_from_templates["num_shifts"] = num_shifts post_merge_label2, templates_array, template_sparse_mask, unit_ids = merge_peak_labels_from_templates( peaks, post_merge_label1, unit_ids, templates_array, template_sparse_mask, - **params["merge_from_templates"], + **params_merge_from_templates, ) else: post_merge_label2 = post_merge_label1.copy()