diff --git a/src/spikeinterface/core/baserecording.py b/src/spikeinterface/core/baserecording.py index 7bb6546664..517662ffb9 100644 --- a/src/spikeinterface/core/baserecording.py +++ b/src/spikeinterface/core/baserecording.py @@ -37,6 +37,8 @@ class BaseRecording(BaseRecordingSnippets): "noise_level_std_scaled", "noise_level_mad_raw", "noise_level_mad_scaled", + "noise_level_rms_raw", + "noise_level_rms_scaled", ] def __init__(self, sampling_frequency: float, channel_ids: list, dtype): diff --git a/src/spikeinterface/extractors/cbin_ibl.py b/src/spikeinterface/extractors/cbin_ibl.py index c70c49e8f8..c7332ef796 100644 --- a/src/spikeinterface/extractors/cbin_ibl.py +++ b/src/spikeinterface/extractors/cbin_ibl.py @@ -7,7 +7,7 @@ import probeinterface from spikeinterface.core import BaseRecording, BaseRecordingSegment -from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts +from spikeinterface.extractors.neuropixels_utils import get_neuropixels_sample_shifts_from_probe from spikeinterface.core.core_tools import define_function_from_class @@ -44,22 +44,13 @@ class CompressedBinaryIblExtractor(BaseRecording): installation_mesg = "To use the CompressedBinaryIblExtractor, install mtscomp: \n\n pip install mtscomp\n\n" - def __init__( - self, folder_path=None, load_sync_channel=False, stream_name="ap", cbin_file_path=None, cbin_file=None - ): + def __init__(self, folder_path=None, load_sync_channel=False, stream_name="ap", cbin_file_path=None): from neo.rawio.spikeglxrawio import read_meta_file try: import mtscomp except ImportError: raise ImportError(self.installation_mesg) - if cbin_file is not None: - warnings.warn( - "The `cbin_file` argument is deprecated and will be removed in version 0.104.0, please use `cbin_file_path` instead", - DeprecationWarning, - stacklevel=2, - ) - cbin_file_path = cbin_file if cbin_file_path is None: folder_path = Path(folder_path) # check bands @@ -124,8 +115,7 @@ def __init__( num_channels_per_adc = 16 else: # NP1.0 num_channels_per_adc = 12 - - sample_shifts = get_neuropixels_sample_shifts(self.get_num_channels(), num_channels_per_adc) + sample_shifts = get_neuropixels_sample_shifts_from_probe(probe, num_channels_per_adc) self.set_property("inter_sample_shift", sample_shifts) self._kwargs = { diff --git a/src/spikeinterface/preprocessing/filter.py b/src/spikeinterface/preprocessing/filter.py index 732b310123..f9d337fe8b 100644 --- a/src/spikeinterface/preprocessing/filter.py +++ b/src/spikeinterface/preprocessing/filter.py @@ -79,7 +79,7 @@ class FilterRecording(BasePreprocessor): def __init__( self, recording, - band=[300.0, 6000.0], + band=(300.0, 6000.0), btype="bandpass", filter_order=5, ftype="butter", @@ -370,7 +370,7 @@ def __init__(self, recording, freq=3000, q=30, margin_ms="auto", dtype=None, **f def causal_filter( recording, direction="forward", - band=[300.0, 6000.0], + band=(300.0, 6000.0), btype="bandpass", filter_order=5, ftype="butter", diff --git a/src/spikeinterface/preprocessing/highpass_spatial_filter.py b/src/spikeinterface/preprocessing/highpass_spatial_filter.py index 9228f5de12..b966883333 100644 --- a/src/spikeinterface/preprocessing/highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/highpass_spatial_filter.py @@ -2,9 +2,9 @@ import numpy as np -from .basepreprocessor import BasePreprocessor, BasePreprocessorSegment -from .filter import fix_dtype -from spikeinterface.core import order_channels_by_depth, get_chunk_with_margin +from spikeinterface.preprocessing.basepreprocessor import BasePreprocessor, BasePreprocessorSegment, BaseRecording +from spikeinterface.preprocessing.filter import fix_dtype +from spikeinterface.core import order_channels_by_depth, get_chunk_with_margin, get_noise_levels from spikeinterface.core.core_tools import define_function_handling_dict_from_class @@ -48,8 +48,17 @@ class HighpassSpatialFilterRecording(BasePreprocessor): Order of spatial butterworth filter highpass_butter_wn : float, default: 0.01 Critical frequency (with respect to Nyquist) of spatial butterworth filter + epsilon : float, default: 0.003 + Value multiplied to RMS values to avoid division by zero during AGC. + random_slice_kwargs : dict | None, default: None + If not None, dictionary of arguments to be passed to `get_noise_levels` when computing + noise levels. dtype : dtype, default: None The dtype of the output traces. If None, the dtype is the same as the input traces + rms_values : np.ndarray | None, default: None + If not None, array of RMS values for each channel to be used during AGC. If None, RMS values are computed + from the recording. This is used to cache pre-computed RMS values, which are only computed once at + initialization. Returns ------- @@ -66,7 +75,7 @@ class HighpassSpatialFilterRecording(BasePreprocessor): def __init__( self, - recording, + recording: BaseRecording, n_channel_pad=60, n_channel_taper=0, direction="y", @@ -74,7 +83,10 @@ def __init__( agc_window_length_s=0.1, highpass_butter_order=3, highpass_butter_wn=0.01, + epsilon=0.003, + random_slice_kwargs=None, dtype=None, + rms_values=None, ): BasePreprocessor.__init__(self, recording) @@ -115,6 +127,14 @@ def __init__( if not apply_agc: agc_window_length_s = None + # Compute or retrieve RMS values + if rms_values is None: + if "noise_level_rms_raw" in recording.get_property_keys(): + rms_values = recording.get_property("noise_level_rms_raw") + else: + random_slice_kwargs = {} if random_slice_kwargs is None else random_slice_kwargs + rms_values = get_noise_levels(recording, method="rms", return_scaled=False, **random_slice_kwargs) + # Pre-compute spatial filtering parameters butter_kwargs = dict(btype="highpass", N=highpass_butter_order, Wn=highpass_butter_wn) sos_filter = scipy.signal.butter(**butter_kwargs, output="sos") @@ -133,6 +153,8 @@ def __init__( order_f, order_r, dtype=dtype, + epsilon=epsilon, + rms_values=rms_values, ) self.add_recording_segment(rec_segment) @@ -145,6 +167,7 @@ def __init__( agc_window_length_s=agc_window_length_s, highpass_butter_order=highpass_butter_order, highpass_butter_wn=highpass_butter_wn, + rms_values=rms_values, ) @@ -161,6 +184,8 @@ def __init__( order_f, order_r, dtype, + epsilon, + rms_values, ): BasePreprocessorSegment.__init__(self, parent_recording_segment) self.parent_recording_segment = parent_recording_segment @@ -185,6 +210,7 @@ def __init__( # get filter params self.sos_filter = sos_filter self.dtype = dtype + self.epsilon_values_for_agc = epsilon * np.array(rms_values) def get_traces(self, start_frame, end_frame, channel_indices): if channel_indices is None: @@ -207,8 +233,9 @@ def get_traces(self, start_frame, end_frame, channel_indices): traces = traces.copy() # apply AGC and keep the gains + traces = traces.astype(np.float32) if self.window is not None: - traces, agc_gains = agc(traces, window=self.window) + traces, agc_gains = agc(traces, window=self.window, epsilons=self.epsilon_values_for_agc) else: agc_gains = None # pad the array with a mirrored version of itself and apply a cosine taper @@ -255,26 +282,35 @@ def get_traces(self, start_frame, end_frame, channel_indices): # ----------------------------------------------------------------------------------------------- -def agc(traces, window, epsilon=1e-8): +def agc(traces, window, epsilons): """ Automatic gain control w_agc, gain = agc(w, window_length=.5, si=.002, epsilon=1e-8) such as w_agc * gain = w - :param traces: seismic array (sample last dimension) - :param window_length: window length (secs) (original default 0.5) - :param si: sampling interval (secs) (original default 0.002) - :param epsilon: whitening (useful mainly for synthetic data) - :return: AGC data array, gain applied to data + + Parameters + ---------- + traces : np.ndarray + Input traces + window : np.ndarray + Window to use for AGC (1D array) + epsilons : np.ndarray[float] + Epsilon values for each channel to avoid division by zero + + Returns + ------- + agc_traces : np.ndarray + AGC applied traces + gain : np.ndarray + Gain applied to the traces """ import scipy.signal gain = scipy.signal.fftconvolve(np.abs(traces), window[:, None], mode="same", axes=0) - gain += (np.sum(gain, axis=0) * epsilon / traces.shape[0])[np.newaxis, :] - dead_channels = np.sum(gain, axis=0) == 0 - traces[:, ~dead_channels] = traces[:, ~dead_channels] / gain[:, ~dead_channels] + traces[:, ~dead_channels] = traces[:, ~dead_channels] / np.maximum(epsilons, gain[:, ~dead_channels]) return traces, gain @@ -282,9 +318,20 @@ def agc(traces, window, epsilon=1e-8): def fcn_extrap(x, f, bounds): """ Extrapolates a flat value before and after bounds - x: array to be filtered - f: function to be applied between bounds (cf. fcn_cosine below) - bounds: 2 elements list or np.array + + Parameters + ---------- + x : np.ndarray + Input array + f : function + Function to be applied between bounds + bounds : list or np.ndarray + 2 elements list or array defining the bounds + + Returns + ------- + y : np.ndarray + Output array """ y = f(x) y[x < bounds[0]] = f(bounds[0]) @@ -298,8 +345,16 @@ def fcn_cosine(bounds): values <= bounds[0]: values values < bounds[0] < bounds[1] : cosine taper values < bounds[1]: bounds[1] - :param bounds: - :return: lambda function + + Parameters + ---------- + bounds : list or np.ndarray + 2 elements list or array defining the bounds + + Returns + ------- + func : function + Lambda function implementing the soft thresholding with cosine taper """ def _cos(x): diff --git a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py index 8ff2aea547..4aa014bbeb 100644 --- a/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py +++ b/src/spikeinterface/preprocessing/tests/test_highpass_spatial_filter.py @@ -3,7 +3,7 @@ import numpy as np from copy import deepcopy -import spikeinterface as si +import spikeinterface.core as si import spikeinterface.preprocessing as spre import spikeinterface.extractors as se from spikeinterface.core import generate_recording @@ -24,7 +24,7 @@ @pytest.mark.skipif( - importlib.util.find_spec("neurodsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB, + importlib.util.find_spec("ibldsp") is None or importlib.util.find_spec("spikeglx") is None or ON_GITHUB, reason="Only local. Requires ibl-neuropixel install", ) @pytest.mark.parametrize("lagc", [False, 1, 300]) @@ -51,32 +51,28 @@ def test_highpass_spatial_filter_real_data(lagc): use DEBUG = true to visualise. """ - import spikeglx - import neurodsp.voltage as voltage + import ibldsp.voltage + import neuropixel - options = dict(lagc=lagc, ntr_pad=25, ntr_tap=50, butter_kwargs=None) - print(options) - - ibl_data, si_recording = get_ibl_si_data() - - si_filtered, _ = run_si_highpass_filter(si_recording, **options) + local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0") + si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap") + si_recording = spre.astype(si_recording, "float") + recording_ps = spre.phase_shift(si_recording) + recording_hp = spre.highpass_filter(recording_ps, freq_min=300, filter_order=3) + recording_hps = spre.highpass_spatial_filter(recording_hp) + raw = si_recording.get_traces().astype(np.float32).T * neuropixel.S2V_AP + si_filtered = recording_hps.get_traces().astype(np.float32).T * neuropixel.S2V_AP - ibl_filtered = run_ibl_highpass_filter(ibl_data.copy(), **options) + destripe = ibldsp.voltage.destripe(raw, fs=30_000, neuropixel_version=1) if DEBUG: - fig, axs = plt.subplots(ncols=4) - axs[0].imshow(si_recording.get_traces(return_in_uV=True)) - axs[0].set_title("SI Raw") - axs[1].imshow(ibl_data.T) - axs[1].set_title("IBL Raw") - axs[2].imshow(si_filtered) - axs[2].set_title("SI Filtered ") - axs[3].imshow(ibl_filtered) - axs[3].set_title("IBL Filtered") + from viewephys.gui import viewephys + + eqc = {} + eqc["si_filtered"] = viewephys(si_filtered, fs=30_000, title="si_filtered") + eqc["ibl_filtered"] = viewephys(destripe, fs=30_000, title="ibl_filtered") - assert np.allclose( - si_filtered, ibl_filtered * 1e6, atol=1e-01, rtol=0 - ) # the differences are entired due to scaling on data load. + np.testing.assert_allclose(si_filtered[12:120, 300:800], destripe[12:120, 300:800], atol=1e-05, rtol=0) @pytest.mark.parametrize("ntr_pad", [None, 0, 31]) @@ -140,24 +136,6 @@ def test_dtype_stability(dtype): # ---------------------------------------------------------------------------------------------------------------------- -def get_ibl_si_data(): - """ - Set fixture to session to ensure origional data is not changed. - """ - import spikeglx - - local_path = si.download_dataset(remote_path="spikeglx/Noise4Sam_g0") - ibl_recording = spikeglx.Reader( - local_path / "Noise4Sam_g0_imec0" / "Noise4Sam_g0_t0.imec0.ap.bin", ignore_warnings=True - ) - ibl_data = ibl_recording.read(slice(None), slice(None), sync=False)[:, :-1].T # cut sync channel - - si_recording = se.read_spikeglx(local_path, stream_id="imec0.ap") - si_recording = spre.astype(si_recording, dtype="float32") - - return ibl_data, si_recording - - def process_args_for_si(si_recording, lagc): """""" if isinstance(lagc, bool) and not lagc: @@ -215,9 +193,10 @@ def run_si_highpass_filter(si_recording, ntr_pad, ntr_tap, lagc, butter_kwargs, def run_ibl_highpass_filter(ibl_data, ntr_pad, ntr_tap, lagc, butter_kwargs): - butter_kwargs, ntr_pad, lagc = process_args_for_ibl(butter_kwargs, ntr_pad, lagc) + import ibldsp.voltage - ibl_filtered = voltage.kfilt(ibl_data, None, ntr_pad, ntr_tap, lagc, butter_kwargs).T + butter_kwargs, ntr_pad, lagc = process_args_for_ibl(butter_kwargs, ntr_pad, lagc) + ibl_filtered = ibldsp.voltage.kfilt(ibl_data, None, ntr_pad, ntr_tap, lagc, butter_kwargs).T return ibl_filtered