From 212b93c572fe801e96b33399acb6820590421cbf Mon Sep 17 00:00:00 2001 From: Olivier Winter Date: Wed, 31 Dec 2025 10:47:00 +0000 Subject: [PATCH] fix bug spatial filter #4175 --- src/spikeinterface/extractors/cbin_ibl.py | 2 + src/spikeinterface/preprocessing/filter.py | 4 +- .../preprocessing/highpass_spatial_filter.py | 11 ++-- .../tests/test_highpass_spatial_filter.py | 65 +++++++------------ 4 files changed, 33 insertions(+), 49 deletions(-) diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index c70c49e8f8..b296200acc 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -49,6 +49,8 @@ def __init__( ): from neo.rawio.spikeglxrawio import read_meta_file + if Path(folder_path).is_file(): + folder_path = Path(folder_path).parent try: import mtscomp except ImportError: diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 732b310123..f9d337fe8b 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -79,7 +79,7 @@ class FilterRecording(BasePreprocessor): def __init__( self, recording, - band=[300.0, 6000.0], + band=(300.0, 6000.0), btype="bandpass", filter_order=5, ftype="butter", @@ -370,7 +370,7 @@ def __init__(self, recording, freq=3000, q=30, margin_ms="auto", dtype=None, **f def causal_filter( recording, direction="forward", - band=[300.0, 6000.0], + band=(300.0, 6000.0), btype="bandpass", filter_order=5, ftype="butter", diff --git a/src/spikeinterface/preprocessing/highpass_spatial_filter.py b/src/spikeinterface/preprocessing/highpass_spatial_filter.py index 9228f5de12..a8f054544f 100644 --- a/src/spikeinterface/preprocessing/highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/highpass_spatial_filter.py @@ -207,6 +207,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces = traces.copy() # apply AGC and keep the gains + traces = traces.astype(np.float32) if self.window is not None: traces, agc_gains = agc(traces, window=self.window) else: @@ -255,7 +256,7 @@ def get_traces(self, start_frame, end_frame, channel_indices): # ----------------------------------------------------------------------------------------------- -def agc(traces, window, epsilon=1e-8): +def agc(traces, window, epsilon=None): """ Automatic gain control w_agc, gain = agc(w, window_length=.5, si=.002, epsilon=1e-8) @@ -268,13 +269,15 @@ def agc(traces, window, epsilon=1e-8): """ import scipy.signal - gain = scipy.signal.fftconvolve(np.abs(traces), window[:, None], mode="same", axes=0) + # default value for epsilon is relative to the rms, loosely matching the IBL 1e-8 for an input in Volts + if epsilon is None: + epsilon = np.std(traces - np.mean(traces)) * 0.003 - gain += (np.sum(gain, axis=0) * epsilon / traces.shape[0])[np.newaxis, :] + gain = scipy.signal.fftconvolve(np.abs(traces), window[:, None], mode="same", axes=0) dead_channels = np.sum(gain, axis=0) == 0 - traces[:, ~dead_channels] = traces[:, ~dead_channels] / gain[:, ~dead_channels] + traces[:, ~dead_channels] = traces[:, ~dead_channels] / np.maximum(epsilon, gain[:, ~dead_channels]) return traces, gain diff --git a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py index 8ff2aea547..a64aab3202 100644 --- a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py @@ -3,7 +3,7 @@ import numpy as np from copy import deepcopy -import spikeinterface as si +import spikeinterface.full as si import spikeinterface.preprocessing as spre import spikeinterface.extractors as se from spikeinterface.core import generate_recording @@ -24,7 +24,7 @@ @pytest.mark.skipif( - importlib.util.find_spec("neurodsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB, + importlib.util.find_spec("ibldsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB, reason="Only local. Requires ibl-neuropixel install", ) @pytest.mark.parametrize("lagc", [False, 1, 300]) @@ -51,32 +51,28 @@ def test_highpass_spatial_filter_real_data(lagc): use DEBUG = true to visualise. """ - import spikeglx - import neurodsp.voltage as voltage + import ibldsp.voltage + import neuropixel - options = dict(lagc=lagc, ntr_pad=25, ntr_tap=50, butter_kwargs=None) - print(options) - - ibl_data, si_recording = get_ibl_si_data() - - si_filtered, _ = run_si_highpass_filter(si_recording, **options) + local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0") + si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap") + si_recording = spre.astype(si_recording, "float") + recording_ps = si.phase_shift(si_recording) + recording_hp = si.highpass_filter(recording_ps, freq_min=300, filter_order=3) + recording_hps = si.highpass_spatial_filter(recording_hp) + raw = si_recording.get_traces().astype(np.float32).T * neuropixel.S2V_AP + si_filtered = recording_hps.get_traces().astype(np.float32).T * neuropixel.S2V_AP - ibl_filtered = run_ibl_highpass_filter(ibl_data.copy(), **options) + destripe = ibldsp.voltage.destripe(raw, fs=30_000, neuropixel_version=1) if DEBUG: - fig, axs = plt.subplots(ncols=4) - axs[0].imshow(si_recording.get_traces(return_in_uV=True)) - axs[0].set_title("SI Raw") - axs[1].imshow(ibl_data.T) - axs[1].set_title("IBL Raw") - axs[2].imshow(si_filtered) - axs[2].set_title("SI Filtered ") - axs[3].imshow(ibl_filtered) - axs[3].set_title("IBL Filtered") + from viewephys.gui import viewephys + + eqc = {} + eqc["si_filtered"] = viewephys(si_filtered, fs=30_000, title="si_filtered") + eqc["ibl_filtered"] = viewephys(destripe, fs=30_000, title="ibl_filtered") - assert np.allclose( - si_filtered, ibl_filtered * 1e6, atol=1e-01, rtol=0 - ) # the differences are entired due to scaling on data load. + np.testing.assert_allclose(si_filtered[12:120, 300:800], destripe[12:120, 300:800], atol=1e-05, rtol=0) @pytest.mark.parametrize("ntr_pad", [None, 0, 31]) @@ -140,24 +136,6 @@ def test_dtype_stability(dtype): # ---------------------------------------------------------------------------------------------------------------------- -def get_ibl_si_data(): - """ - Set fixture to session to ensure origional data is not changed. - """ - import spikeglx - - local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0") - ibl_recording = spikeglx.Reader( - local_path / "Noise4Sam_g0_imec0" / "Noise4Sam_g0_t0.imec0.ap.bin", ignore_warnings=True - ) - ibl_data = ibl_recording.read(slice(None), slice(None), sync=False)[:, :-1].T # cut sync channel - - si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap") - si_recording = spre.astype(si_recording, dtype="float32") - - return ibl_data, si_recording - - def process_args_for_si(si_recording, lagc): """""" if isinstance(lagc, bool) and not lagc: @@ -215,9 +193,10 @@ def run_si_highpass_filter(si_recording, ntr_pad, ntr_tap, lagc, butter_kwargs, def run_ibl_highpass_filter(ibl_data, ntr_pad, ntr_tap, lagc, butter_kwargs): - butter_kwargs, ntr_pad, lagc = process_args_for_ibl(butter_kwargs, ntr_pad, lagc) + import ibldsp.voltage - ibl_filtered = voltage.kfilt(ibl_data, None, ntr_pad, ntr_tap, lagc, butter_kwargs).T + butter_kwargs, ntr_pad, lagc = process_args_for_ibl(butter_kwargs, ntr_pad, lagc) + ibl_filtered = ibldsp.voltage.kfilt(ibl_data, None, ntr_pad, ntr_tap, lagc, butter_kwargs).T return ibl_filtered