diff --git a/spiketoolkit/postprocessing/postprocessing_tools.py b/spiketoolkit/postprocessing/postprocessing_tools.py index b010b03e..a6795bb8 100644 --- a/spiketoolkit/postprocessing/postprocessing_tools.py +++ b/spiketoolkit/postprocessing/postprocessing_tools.py @@ -8,15 +8,23 @@ from joblib import Parallel, delayed from spikeextractors import RecordingExtractor, SortingExtractor import csv +from tqdm import tqdm +from copy import copy +import time +import os from .utils import update_all_param_dicts_with_kwargs, select_max_channels_from_waveforms, \ - get_max_channels_per_waveforms, select_max_channels_from_templates + divide_recording_into_time_chunks, get_unit_waveforms_for_chunk, get_max_channels_per_waveforms, \ + select_max_channels_from_templates -def get_unit_waveforms(recording, sorting, unit_ids=None, channel_ids=None, - return_idxs=False, **kwargs): - ''' +def get_unit_waveforms(recording, sorting, unit_ids=None, channel_ids=None, return_idxs=False, chunk_size=None, + chunk_mb=500, **kwargs): + """ Computes the spike waveforms from a recording and sorting extractor. + The recording is split in chunks (the size in Mb is set with the chunk_mb argument) and all waveforms are extracted + for each chunk and then re-assembled. If multiple jobs are used (n_jobs > 1), more and smaller chunks are created + and processed in parallel. Parameters ---------- @@ -30,6 +38,10 @@ def get_unit_waveforms(recording, sorting, unit_ids=None, channel_ids=None, List of channels ids to compute waveforms from return_idxs: bool If True, spike indexes and channel indexes are returned + chunk_size: int + Size of chunks in number of samples. If None, it is automatically calculated + chunk_mb: int + Size of chunks in Mb (default 500 Mb) **kwargs: Keyword arguments A dictionary with default values can be retrieved with: st.postprocessing.get_waveforms_params(): @@ -70,7 +82,7 @@ def get_unit_waveforms(recording, sorting, unit_ids=None, channel_ids=None, List of spike indexes for which waveforms are computed. 
Returned if 'return_idxs' is True channel_indexes: list List of max channel indexes - ''' + """ if isinstance(unit_ids, (int, np.integer)): unit_ids = [unit_ids] elif unit_ids is None: @@ -106,6 +118,9 @@ def get_unit_waveforms(recording, sorting, unit_ids=None, channel_ids=None, spike_index_list = [] channel_index_list = [] + if max_channels_per_waveforms is None: + max_channels_per_waveforms = len(channel_ids) + if 'waveforms' in sorting.get_shared_unit_spike_feature_names() and not recompute_info: for unit_id in unit_ids: waveforms = sorting.get_unit_spike_features(unit_id, 'waveforms') @@ -126,11 +141,50 @@ def get_unit_waveforms(recording, sorting, unit_ids=None, channel_ids=None, dtype = recording.get_dtype() if n_jobs is None: - n_jobs = 0 + n_jobs = 1 + if n_jobs == 0: + n_jobs = 1 + + if seed is not None: + np.random.seed(seed) + + # num_channels = recording.get_num_channels() + num_frames = recording.get_num_frames() + fs = recording.get_sampling_frequency() + n_pad = [int(ms_before * fs / 1000), int(ms_after * fs / 1000)] + + # set chunk size + if chunk_size is not None: + chunk_size = int(chunk_size) + elif chunk_mb is not None: + n_bytes = np.dtype(recording.get_dtype()).itemsize + max_size = int(chunk_mb * 1e6) # set Mb per chunk + chunk_size = max_size // (recording.get_num_channels() * n_bytes) + + if n_jobs > 1: + chunk_size /= n_jobs + + + # chunk_size = num_bytes_per_chunk / num_bytes_per_frame + padding_size = 100 + n_pad[0] + n_pad[1] # a bit excess padding + chunks = divide_recording_into_time_chunks( + num_frames=num_frames, + chunk_size=chunk_size, + padding_size=padding_size + ) + n_chunk = len(chunks) + + if verbose: + print(f"Number of chunks: {len(chunks)} - Number of jobs: {n_jobs}") + + # pre-map memmap files + n_channels = len(channel_ids) + if len(channel_ids) < recording.get_num_channels(): + recording = se.SubRecordingExtractor(recording, channel_ids=channel_ids) if not recording.check_if_dumpable(): if n_jobs > 1: - n_jobs = 0 + n_jobs = 1 print("RecordingExtractor is not dumpable and can't be processed in parallel") rec_arg = recording else: @@ -138,99 +192,178 @@ def get_unit_waveforms(recording, sorting, unit_ids=None, channel_ids=None, rec_arg = recording.dump_to_dict() else: rec_arg = recording - if not sorting.check_if_dumpable(): - if n_jobs > 1: - n_jobs = 0 - print("SortingExtractor is not dumpable and can't be processed in parallel") - sort_arg = sorting + + if memmap: + all_unit_waveforms = [] + for unit_id in unit_ids: + fname = f'waveforms_{unit_id}.raw' + len_wf = len(sorting.get_unit_spike_train(unit_id)) + if max_spikes_per_unit is not None: + if len_wf > max_spikes_per_unit: + len_wf = max_spikes_per_unit + shape = (len_wf, n_channels, sum(n_pad)) + if (sorting.get_tmp_folder() / fname).is_file(): # remove existing files + os.remove(str(sorting.get_tmp_folder() / fname)) + arr = sorting.allocate_array(shape=shape, dtype=dtype, name=fname, memmap=memmap) + all_unit_waveforms.append(arr) else: - if n_jobs > 1: - sort_arg = sorting.dump_to_dict() + all_unit_waveforms = [[] for ii in range(len(unit_ids))] + + if verbose and n_jobs == 1: + chunk_iter = tqdm(range(n_chunk), ascii=True, desc="Extracting waveforms in chunks") + else: + chunk_iter = range(n_chunk) + + # Pre-select spikes to include + spike_times_to_include = [] + if max_spikes_per_unit is not None: + for i, unit in enumerate(unit_ids): + spiketrain = sorting.get_unit_spike_train(unit) + num_spikes = len(spiketrain) + if num_spikes > max_spikes_per_unit: + spike_idxs = 
np.sort(np.random.permutation(num_spikes)[:max_spikes_per_unit]) + spike_index_list.append(spike_idxs) + spike_times_to_include.append(spiketrain[spike_idxs]) + else: + spike_index_list.append(np.arange(num_spikes)) + spike_times_to_include.append(None) + else: + for u in unit_ids: + spike_index_list.append(None) + spike_times_to_include.append(None) + + # pre-compute spikes for each chunk + times_in_all_chunks = [] + start_spike_idxs = [] + n_spikes = np.zeros(len(unit_ids), dtype='int64') + for chunk in chunks: + times_in_chunk_units = [] + start_spike_idxs.append(copy(n_spikes)) + for i, unit in enumerate(unit_ids): + times = sorting.get_unit_spike_train(unit_id=unit) + times_in_chunk = [] + if spike_times_to_include[i] is not None: + spike_times = spike_times_to_include[i] + spike_time_idxs = np.where((spike_times >= chunk['istart']) + & (spike_times < chunk['iend']))[0] # exclude padding + + if len(spike_time_idxs) > 0: + times_in_chunk = spike_times[spike_time_idxs] + else: + spike_time_idxs = np.where((times >= chunk['istart']) + & (times < chunk['iend']))[0] + times_in_chunk = times[spike_time_idxs] + n_spikes[i] += len(times_in_chunk) + times_in_chunk_units.append(times_in_chunk) + + #assert np.all(times_in_chunk >= chunk['istart']) and np.all(times_in_chunk < chunk['iend']) + times_in_all_chunks.append(times_in_chunk_units) + + # wf_chunk_idxs = np.zeros(len(unit_ids), dtype='int') + + if n_jobs == 1: + for ii in chunk_iter: + # chunk: {istart, iend, istart_with_padding, iend_with_padding} # include padding + unit_waveforms = _extract_waveforms_one_chunk(ii, recording, chunks, unit_ids, n_pad, + times_in_all_chunks, start_spike_idxs, + all_unit_waveforms, memmap, dtype, False) + if not memmap: + for i_unit, unit in enumerate(unit_ids): + wf = unit_waveforms[i_unit] + wf = wf.astype(dtype) + all_unit_waveforms[ii].append(wf) + else: + # waveforms are saved directly to the memmap file if + unit_waveforms, = Parallel(n_jobs=n_jobs, backend=joblib_backend)( + delayed(_extract_waveforms_one_chunk)(ii, rec_arg, chunks, unit_ids, n_pad, + times_in_all_chunks, start_spike_idxs, + all_unit_waveforms, memmap, dtype, verbose,) + for ii in chunk_iter) + + if memmap: + waveform_list = all_unit_waveforms + else: + # concatenate the results over the chunks + if len(chunks) > 1: + waveform_list = [np.concatenate(unit_waveforms[i_unit], axis=0) for i_unit in range(len(unit_ids))] + else: + waveform_list = unit_waveforms + + # return correct max channels + if grouping_property is not None: + if grouping_property not in recording.get_shared_channel_property_names(): + raise ValueError("'grouping_property' should be a property of recording extractors") + if compute_property_from_recording: + compute_sorting_group = True + elif grouping_property not in sorting.get_shared_unit_property_names(): + warnings.warn('Grouping property not in sorting extractor. 
Computing it from the recording extractor') + compute_sorting_group = True + else: + compute_sorting_group = False + + waveforms_reduced_channels = [] + channel_groups = np.array([recording.get_channel_property(ch, grouping_property) + for ch in recording.get_channel_ids()]) + unit_groups = [] + if compute_sorting_group: + # extract unit groups + for wf in waveform_list: + mean_waveforms = np.squeeze(np.mean(wf, axis=0)) + max_amp_elec = np.unravel_index(mean_waveforms.argmin(), mean_waveforms.shape)[0] + unit_group = recording.get_channel_property(recording.get_channel_ids()[max_amp_elec], + grouping_property) + unit_groups.append(unit_group) else: - sort_arg = sorting + for u in sorting.get_unit_ids(): + unit_group = sorting.get_unit_property(u, grouping_property) + unit_groups.append(unit_group) - fs = recording.get_sampling_frequency() - n_pad = [int(ms_before * fs / 1000), int(ms_after * fs / 1000)] + for (wf, unit_group) in zip(waveform_list, unit_groups): + channel_unit_group = np.where(channel_groups == unit_group)[0] - if n_jobs in [0, 1]: - if memmap: - n_channels = get_max_channels_per_waveforms(recording, grouping_property, channel_ids, - max_channels_per_waveforms) - # pre-construct memmap arrays - for unit_id in unit_ids: - fname = 'waveforms_' + str(unit_id) + '.raw' - len_wf = len(sorting.get_unit_spike_train(unit_id)) - if max_spikes_per_unit is not None: - if len_wf > max_spikes_per_unit: - len_wf = max_spikes_per_unit - shape = (len_wf, n_channels, sum(n_pad)) - arr = sorting.allocate_array(shape=shape, dtype=dtype, name=fname, memmap=memmap) - - waveforms, indexes, max_channel_idxs = _extract_waveforms_one_unit(unit_id, rec_arg, sort_arg, - channel_ids, - unit_ids, grouping_property, - compute_property_from_recording, - max_channels_per_waveforms, - max_spikes_per_unit, n_pad, - dtype, seed, verbose, - memmap_array=arr) - waveform_list.append(waveforms) - spike_index_list.append(indexes) - channel_index_list.append(max_channel_idxs) - else: - for unit_id in unit_ids: - waveforms, indexes, max_channel_idxs = _extract_waveforms_one_unit(unit_id, rec_arg, sort_arg, - channel_ids, unit_ids, - grouping_property, - compute_property_from_recording, - max_channels_per_waveforms, - max_spikes_per_unit, n_pad, - dtype, seed, verbose, - memmap_array=None) - waveform_list.append(waveforms) - spike_index_list.append(indexes) - channel_index_list.append(max_channel_idxs) + if len(channel_unit_group) < max_channels_per_waveforms: + max_channel_idxs = channel_unit_group + else: + subrec = se.SubRecordingExtractor(recording, channel_ids=list(channel_unit_group)) + max_channel_idxs = select_max_channels_from_waveforms(wf, subrec, max_channels_per_waveforms) + + channel_index_list.append(max_channel_idxs) + waveform = wf[:, max_channel_idxs] + # some channels are missing - re-instantiate object + if memmap: + memmap_file = wf.filename + del wf + os.remove(memmap_file) + memmap_array = np.memmap(memmap_file, mode='w+', shape=waveform.shape, + dtype=waveform.dtype) + memmap_array[:] = waveform + del(waveform) + waveforms_reduced_channels.append(memmap_array) + else: + waveforms_reduced_channels.append(waveform) + waveform_list = waveforms_reduced_channels else: - if memmap: - memmap_arrays = [] - n_channels = get_max_channels_per_waveforms(recording, grouping_property, channel_ids, - max_channels_per_waveforms) - # pre-construct memmap arrays - for unit_id in unit_ids: - fname = 'waveforms_' + str(unit_id) + '.raw' - len_wf = len(sorting.get_unit_spike_train(unit_id)) - if 
max_spikes_per_unit is not None: - if len_wf > max_spikes_per_unit: - len_wf = max_spikes_per_unit - shape = (len_wf, n_channels, sum(n_pad)) - arr = sorting.allocate_array(shape=shape, dtype=dtype, name=fname, memmap=memmap) - memmap_arrays.append(arr) - output_list = Parallel(n_jobs=n_jobs, backend=joblib_backend)( - delayed(_extract_waveforms_one_unit)(unit, rec_arg, sort_arg, channel_ids, - unit_ids, grouping_property, - compute_property_from_recording, - max_channels_per_waveforms, - max_spikes_per_unit, n_pad, - dtype, seed, verbose, mem_array, ) - for (unit, mem_array) in zip(unit_ids, memmap_arrays)) - for i, out in enumerate(output_list): - waveform_list.append(out[0]) - spike_index_list.append(out[1]) - channel_index_list.append(out[2]) + if max_channels_per_waveforms < len(recording.get_channel_ids()): + waveforms_reduced_channels = [] + for wf in waveform_list: + max_channel_idxs = select_max_channels_from_waveforms(wf, recording, max_channels_per_waveforms) + channel_index_list.append(max_channel_idxs) + waveform = wf[:, max_channel_idxs] + # some channels are missing - re-instantiate object + if memmap: + memmap_file = wf.filename + del wf + os.remove(memmap_file) + memmap_array = np.memmap(memmap_file, mode='w+', shape=waveform.shape, + dtype=waveform.dtype) + memmap_array[:] = waveform + waveforms_reduced_channels.append(memmap_array) + else: + waveforms_reduced_channels.append(waveform) + waveform_list = waveforms_reduced_channels else: - output_list = Parallel(n_jobs=n_jobs, backend=joblib_backend)( - delayed(_extract_waveforms_one_unit)(unit, rec_arg, sort_arg, channel_ids, - unit_ids, grouping_property, - compute_property_from_recording, - max_channels_per_waveforms, - max_spikes_per_unit, n_pad, - dtype, seed, verbose, None, ) - for unit in unit_ids) - - for out in output_list: - waveform_list.append(out[0]) - spike_index_list.append(out[1]) - channel_index_list.append(out[2]) + for wf in waveform_list: + channel_index_list.append(channel_ids) if save_property_or_features: for i, unit_id in enumerate(unit_ids): @@ -534,17 +667,16 @@ def get_unit_amplitudes(recording, sorting, unit_ids=None, channel_ids=None, ret frames_before = int(params_dict['frames_before']) frames_after = int(params_dict['frames_after']) memmap = params_dict['memmap'] - seed = params_dict['seed'] max_spikes_per_unit = params_dict['max_spikes_per_unit'] save_property_or_features = params_dict['save_property_or_features'] recompute_info = params_dict['recompute_info'] - n_jobs = params_dict['n_jobs'] - joblib_backend = params_dict['joblib_backend'] - + ms_before = params_dict['ms_before'] dtype = recording.get_dtype() + assert peak in ['neg', 'pos', 'both'], "'peak' can be 'neg', 'pos', or 'both'" amp_list = [] spike_index_list = [] + center_frame = int(ms_before / 1000 * recording.get_sampling_frequency()) if 'amplitudes' in sorting.get_shared_unit_spike_feature_names() and not recompute_info: for unit_id in unit_ids: amplitudes = sorting.get_unit_spike_features(unit_id, 'amplitudes') @@ -556,84 +688,51 @@ def get_unit_amplitudes(recording, sorting, unit_ids=None, channel_ids=None, ret indexes = np.arange(len(amplitudes)) spike_index_list.append(indexes) else: - if n_jobs is None: - n_jobs = 0 - - if not recording.check_if_dumpable(): - if n_jobs > 1: - n_jobs = 0 - print("RecordingExtractor is not dumpable and can't be processed in parallel") - rec_arg = recording + # pre-construct memmap arrays + if memmap: + for unit_id in unit_ids: + fname = 'waveforms_' + str(unit_id) + '.raw' + len_amp = 
len(sorting.get_unit_spike_train(unit_id)) + if max_spikes_per_unit is not None: + if len_amp > max_spikes_per_unit: + len_amp = max_spikes_per_unit + shape = len_amp + if (sorting.get_tmp_folder() / fname).is_file(): # remove existing files + os.remove(str(sorting.get_tmp_folder() / fname)) + arr = sorting.allocate_array(shape=shape, dtype=dtype, name=fname, memmap=memmap) + amp_list.append(arr) else: - if n_jobs > 1: - rec_arg = recording.dump_to_dict() - else: - rec_arg = recording - if not sorting.check_if_dumpable(): - if n_jobs > 1: - n_jobs = 0 - print("SortingExtractor is not dumpable and can't be processed in parallel") - sort_arg = sorting - else: - if n_jobs > 1: - sort_arg = sorting.dump_to_dict() - else: - sort_arg = sorting + amp_list = [[] for ii in range(len(unit_ids))] + + waveforms, spike_index_list, channel_index_list = get_unit_waveforms(recording, sorting, unit_ids, channel_ids, + return_idxs=True, **kwargs) + templates = [np.median(wf, 0) for wf in waveforms] + max_channels = [np.unravel_index(np.argmax(np.abs(t)), t.shape)[0] for t in templates] + + for i, (u, wf) in enumerate(zip(unit_ids, waveforms)): + wf_cut = wf[:, max_channels[i], center_frame - frames_before:center_frame + frames_after] + if peak == 'both': + amps = np.max(np.abs(wf_cut), axis=-1) + if len(amps.shape) > 1: + amps = np.max(wf) + elif peak == 'neg': + amps = np.min(wf_cut, axis=-1) + if len(amps.shape) > 1: + amps = np.min(wf, axis=-1) + else: # 'pos' + amps = np.max(wf_cut, axis=-1) + if len(amps.shape) > 1: + amps = np.max(amps, axis=-1) + + if method == 'relative': + amps /= np.median(amps) + amps = amps.astype(dtype) - if n_jobs in [0, 1]: - if memmap: - # pre-construct memmap arrays - for unit_id in unit_ids: - fname = 'amplitudes_' + str(unit_id) + '.raw' - len_amp = len(sorting.get_unit_spike_train(unit_id)) - if max_spikes_per_unit is not None: - if len_amp > max_spikes_per_unit: - len_amp = max_spikes_per_unit - shape = len_amp - arr = sorting.allocate_array(shape=shape, dtype=dtype, name=fname, memmap=memmap) - - amplitudes, indexes = _extract_amplitudes_one_unit(unit_id, rec_arg, sort_arg, channel_ids, - max_spikes_per_unit, frames_before, frames_after, - peak, method, seed, memmap_array=arr) - amp_list.append(amplitudes) - spike_index_list.append(indexes) - else: - for unit_id in unit_ids: - amplitudes, indexes = _extract_amplitudes_one_unit(unit_id, rec_arg, sort_arg, channel_ids, - max_spikes_per_unit, frames_before, frames_after, - peak, method, seed, memmap_array=None) - amp_list.append(amplitudes) - spike_index_list.append(indexes) - else: if memmap: - memmap_arrays = [] - # pre-construct memmap arrays - for unit_id in unit_ids: - fname = 'amplitudes_' + str(unit_id) + '.raw' - len_amp = len(sorting.get_unit_spike_train(unit_id)) - if max_spikes_per_unit is not None: - if len_amp > max_spikes_per_unit: - len_amp = max_spikes_per_unit - shape = len_amp - arr = sorting.allocate_array(shape=shape, dtype=dtype, name=fname, memmap=memmap) - memmap_arrays.append(arr) - output_list = Parallel(n_jobs=n_jobs, backend=joblib_backend)( - delayed(_extract_amplitudes_one_unit)(unit_id, rec_arg, sort_arg, channel_ids, - max_spikes_per_unit, frames_before, frames_after, - peak, method, seed, mem_array, ) - for (unit_id, mem_array) in zip(unit_ids, memmap_arrays)) - for i, out in enumerate(output_list): - amp_list.append(out[0]) - spike_index_list.append(out[1]) + amp_list[i] = amps + del amps else: - output_list = Parallel(n_jobs=n_jobs, backend=joblib_backend)( - 
delayed(_extract_amplitudes_one_unit)(unit_id, rec_arg, sort_arg, channel_ids, - max_spikes_per_unit, frames_before, frames_after, - peak, method, seed, None, ) - for unit_id in unit_ids) - for i, out in enumerate(output_list): - amp_list.append(out[0]) - spike_index_list.append(out[1]) + amp_list[i] = amps if save_property_or_features: for i, unit_id in enumerate(unit_ids): @@ -707,12 +806,12 @@ def compute_channel_spiking_activity(recording, channel_ids=None, detect_thresho if method == 'simple': if not recording.check_if_dumpable(): if n_jobs > 1: - n_jobs = 0 + n_jobs = 1 print("RecordingExtractor is not dumpable and can't be processedin parallel") else: rec_arg = recording.make_serialized_dict() - if n_jobs in [0, 1]: + if n_jobs == 1: for i, ch in enumerate(channel_ids): if verbose: print(f'Detecting spikes on channel {ch}') @@ -1759,235 +1858,45 @@ def _extract_activity_one_channel(rec_arg, ch, detect_sign, detect_threshold, st return activity -def _extract_waveforms_one_unit(unit, rec_arg, sort_arg, channel_ids, unit_ids, grouping_property, - compute_property_from_recording, max_channels_per_waveforms, max_spikes_per_unit, - n_pad, dtype, seed, verbose, memmap_array=None): - if isinstance(rec_arg, dict): - recording = se.load_extractor_from_dict(rec_arg) - else: - recording = rec_arg - if isinstance(sort_arg, dict): - sorting = se.load_extractor_from_dict(sort_arg) - else: - sorting = sort_arg - - if grouping_property is not None: - if grouping_property not in recording.get_shared_channel_property_names(): - raise ValueError("'grouping_property' should be a property of recording extractors") - if compute_property_from_recording: - compute_sorting_group = True - elif grouping_property not in sorting.get_shared_unit_property_names(): - warnings.warn('Grouping property not in sorting extractor. 
Computing it from the recording extractor') - compute_sorting_group = True - else: - compute_sorting_group = False - - if not compute_sorting_group: - rec_list, rec_props = recording.get_sub_extractors_by_property(grouping_property, - return_property_list=True) - sort_list, sort_props = sorting.get_sub_extractors_by_property(grouping_property, - return_property_list=True) - if len(rec_props) != len(sort_props): - print('Different' + grouping_property + ' numbers: using largest number of ' + grouping_property) - if len(rec_props) > len(sort_props): - for i_r, rec in enumerate(rec_props): - if rec not in sort_props: - print('Inserting None for property ', rec) - sort_list.insert(i_r, None) - else: - for i_s, sort in enumerate(sort_props): - if sort not in rec_props: - rec_list.insert(i_s, None) - else: - assert len(rec_list) == len(sort_list) - - for i_list, (rec, sort) in enumerate(zip(rec_list, sort_list)): - if sort is not None and rec is not None: - for i, unit_id in enumerate(unit_ids): - if unit == unit_id and unit in sort.get_unit_ids(): - channel_ids = rec.get_channel_ids() - - if max_spikes_per_unit is None: - max_spikes = len(sort.get_unit_spike_train(unit_id)) - else: - max_spikes = max_spikes_per_unit - - if max_channels_per_waveforms is None: - max_channels_per_waveforms = len(channel_ids) - - if verbose: - print('Waveform ' + str(i + 1) + '/' + str(len(unit_ids))) - wf, indexes = _get_random_spike_waveforms(recording=rec, - sorting=sort, - unit=unit_id, - max_spikes_per_unit=max_spikes, - snippet_len=n_pad, - channel_ids=channel_ids, - seed=seed) - wf = wf.astype(dtype) - if max_channels_per_waveforms < len(channel_ids): - max_channel_idxs = select_max_channels_from_waveforms(wf, rec, - max_channels_per_waveforms) - else: - max_channel_idxs = np.arange(rec.get_num_channels()) - wf = wf[:, max_channel_idxs] - - if memmap_array is None: - waveforms = wf - else: - if np.all(wf.shape == memmap_array.shape): - memmap_array[:] = wf - else: - # some channels are missing - re-instantiate object - memmap_file = memmap_array.filename - del memmap_array - memmap_array = np.memmap(memmap_file, mode='w+', shape=wf.shape, dtype=wf.dtype) - memmap_array[:] = wf - waveforms = memmap_array - return waveforms, list(indexes), list(max_channel_idxs) - else: - for i, unit_id in enumerate(unit_ids): - if unit == unit_id: - if channel_ids is None: - channel_ids = recording.get_channel_ids() - - rec = se.SubRecordingExtractor(recording, channel_ids=channel_ids) - rec_groups = np.array(rec.get_channel_groups()) - groups, count = np.unique(rec_groups, return_counts=True) - if max_channels_per_waveforms is None: - max_channels_per_waveforms = np.max(count) - elif max_channels_per_waveforms >= np.max(count): - max_channels_per_waveforms = np.max(count) - - if max_spikes_per_unit is None: - max_spikes = len(sorting.get_unit_spike_train(unit_id)) - else: - max_spikes = max_spikes_per_unit +def _extract_waveforms_one_chunk(i, rec_arg, chunks, unit_ids, n_pad, times_in_chunk, cumulative_n_spikes, + waveforms_file, memmap, dtype, verbose): + chunk = chunks[i] + times_this_chunk = times_in_chunk[i] + n_spikes = cumulative_n_spikes[i] - if verbose: - print('Waveform ' + str(i + 1) + '/' + str(len(unit_ids))) - wf, indexes = _get_random_spike_waveforms(recording=recording, - sorting=sorting, - unit=unit_id, - max_spikes_per_unit=max_spikes, - snippet_len=n_pad, - channel_ids=channel_ids, - seed=seed) - wf = wf.astype(dtype) - mean_waveforms = np.squeeze(np.mean(wf, axis=0)) - max_amp_elec = 
np.unravel_index(mean_waveforms.argmin(), mean_waveforms.shape)[0] - group = recording.get_channel_property(recording.get_channel_ids()[max_amp_elec], grouping_property) - - elec_group = np.where(rec_groups == group)[0] - wf = wf[:, elec_group, :] - if max_channels_per_waveforms < len(elec_group): - max_channel_idxs = select_max_channels_from_waveforms(wf, rec, max_channels_per_waveforms) - else: - max_channel_idxs = np.arange(len(elec_group)) - wf = wf[:, max_channel_idxs] - - if memmap_array is None: - waveforms = wf - else: - if np.all(wf.shape == memmap_array.shape): - memmap_array[:] = wf - else: - # some channels are missing - re-instantiate object - memmap_file = memmap_array.filename - del memmap_array - memmap_array = np.memmap(memmap_file, mode='w+', shape=wf.shape, dtype=wf.dtype) - memmap_array[:] = wf - waveforms = memmap_array - return waveforms, list(indexes), list(max_channel_idxs), - - else: - for i, unit_id in enumerate(unit_ids): - if unit == unit_id: - if channel_ids is None: - channel_ids = recording.get_channel_ids() - - if max_channels_per_waveforms is None: - max_channels_per_waveforms = len(channel_ids) - - if max_spikes_per_unit is None: - max_spikes = len(sorting.get_unit_spike_train(unit_id)) - else: - max_spikes = max_spikes_per_unit - - if verbose: - print('Waveform ' + str(i + 1) + '/' + str(len(unit_ids))) - - wf, indexes = _get_random_spike_waveforms(recording=recording, - sorting=sorting, - unit=unit_id, - max_spikes_per_unit=max_spikes, - snippet_len=n_pad, - channel_ids=channel_ids, - seed=seed) - wf = wf.astype(dtype) - if max_channels_per_waveforms < len(channel_ids): - max_channel_idxs = select_max_channels_from_waveforms(wf, recording, max_channels_per_waveforms) - else: - max_channel_idxs = np.arange(len(channel_ids)) - wf = wf[:, max_channel_idxs] - - if memmap_array is None: - waveforms = wf - else: - if np.all(wf.shape == memmap_array.shape): - memmap_array[:] = wf - else: - # some channels are missing - re-instantiate object - memmap_file = memmap_array.filename - del memmap_array - memmap_array = np.memmap(memmap_file, mode='w+', shape=wf.shape, dtype=wf.dtype) - memmap_array[:] = wf - waveforms = memmap_array - return waveforms, list(indexes), list(max_channel_idxs), - - -def _extract_amplitudes_one_unit(unit, rec_arg, sort_arg, channel_ids, max_spikes_per_unit, frames_before, frames_after, - peak, method, seed, memmap_array=None): + if verbose: + print(f"Chunk {i+1}: extracting waveforms") if isinstance(rec_arg, dict): recording = se.load_extractor_from_dict(rec_arg) else: recording = rec_arg - if isinstance(sort_arg, dict): - sorting = se.load_extractor_from_dict(sort_arg) - else: - sorting = sort_arg - - spike_train = sorting.get_unit_spike_train(unit) - if max_spikes_per_unit < len(spike_train): - indexes = np.sort(np.random.RandomState(seed=seed).permutation(len(spike_train))[:max_spikes_per_unit]) - else: - indexes = np.arange(len(spike_train)) - spike_train = spike_train[indexes] - - snippets = recording.get_snippets(reference_frames=spike_train, - snippet_len=[frames_before, frames_after], channel_ids=channel_ids) - if peak == 'both': - amps = np.max(np.abs(snippets), axis=-1) - if len(amps.shape) > 1: - amps = np.max(amps, axis=-1) - elif peak == 'neg': - amps = np.min(snippets, axis=-1) - if len(amps.shape) > 1: - amps = np.min(amps, axis=-1) - elif peak == 'pos': - amps = np.max(snippets, axis=-1) - if len(amps.shape) > 1: - amps = np.max(amps, axis=-1) - else: - raise Exception("'peak' can be 'neg', 'pos', or 'both'") + t_start = 
time.perf_counter() + # chunk: {istart, iend, istart_with_padding, iend_with_padding} # include padding + recording_chunk = se.SubRecordingExtractor( + parent_recording=recording, + start_frame=chunk['istart_with_padding'], + end_frame=chunk['iend_with_padding'] + ) + + # num_events_in_chunk x num_channels_in_nbhd[unit_id] x len_of_one_snippet + unit_waveforms = get_unit_waveforms_for_chunk( + recording=recording_chunk, + chunk=chunk, + unit_ids=unit_ids, + snippet_len=n_pad, + times_in_chunk=times_this_chunk + ) + t_stop = time.perf_counter() + if verbose: + print(f"Chunk {i+1}: waveforms extracted in {t_stop - t_start}s") - if method == 'relative': - amps /= np.median(amps) + if memmap: + for i_unit, unit in enumerate(unit_ids): + wf = unit_waveforms[i_unit] + wf = wf.astype(dtype) - if memmap_array is None: - amplitudes = amps + if len(wf) > 0: + waveforms_file[i_unit][n_spikes[i_unit]:n_spikes[i_unit] + len(wf)] = wf + return None else: - memmap_array[:] = amps - amplitudes = memmap_array - - return amplitudes, list(indexes), + return unit_waveforms diff --git a/spiketoolkit/postprocessing/utils.py b/spiketoolkit/postprocessing/utils.py index 4925f750..119b547a 100644 --- a/spiketoolkit/postprocessing/utils.py +++ b/spiketoolkit/postprocessing/utils.py @@ -124,3 +124,65 @@ def get_max_channels_per_waveforms(recording, grouping_property, channel_ids, ma else: n_channels = max_channels_per_waveforms return n_channels + + +def extract_snippet_from_traces( + traces, + start_frame, + end_frame, +): + if (0 <= start_frame) and (end_frame <= traces.shape[1]): + x = traces[:, start_frame:end_frame] + else: + # handle edge cases + x = np.zeros((traces.shape[0], end_frame - start_frame), dtype=traces.dtype) + i1 = int(max(0, start_frame)) + i2 = int(min(traces.shape[1], end_frame)) + x[:, (i1 - start_frame):(i2 - start_frame)] = traces[:, i1:i2] + return x + + +def get_unit_waveforms_for_chunk( + recording, + chunk, + unit_ids, + snippet_len, + times_in_chunk, +): + # chunks are chosen small enough so that all traces can be loaded into memory + traces = recording.get_traces() + frame_offset = chunk['istart'] - chunk['istart_with_padding'] + + unit_waveforms = [] + for i_unit, unit_id in enumerate(unit_ids): + # find indexes in chunk + if len(times_in_chunk[i_unit]) > 0: + # Adjust time with padding + try: + snippets = [extract_snippet_from_traces(traces, + start_frame=frame_offset + int(t) - snippet_len[0], + end_frame=frame_offset + int(t) + snippet_len[1]) + for t in times_in_chunk[i_unit] - chunk['istart']] + except: + raise Exception + unit_waveforms.append(np.stack(snippets)) + else: + unit_waveforms.append(np.zeros((0, recording.get_num_channels(), + snippet_len[0] + snippet_len[1]), dtype=traces.dtype)) + + return unit_waveforms + + +def divide_recording_into_time_chunks(num_frames, chunk_size, padding_size): + chunks = [] + ii = 0 + while ii < num_frames: + ii2 = int(min(ii + chunk_size, num_frames)) + chunks.append(dict( + istart=ii, + iend=ii2, + istart_with_padding=int(max(0, ii - padding_size)), + iend_with_padding=int(min(num_frames, ii2 + padding_size)) + )) + ii = ii2 + return chunks diff --git a/spiketoolkit/preprocessing/bandpass_filter.py b/spiketoolkit/preprocessing/bandpass_filter.py index 8d95ccce..999e5b45 100644 --- a/spiketoolkit/preprocessing/bandpass_filter.py +++ b/spiketoolkit/preprocessing/bandpass_filter.py @@ -91,7 +91,7 @@ def _create_filter_kernel(N, sampling_frequency, freq_min, freq_max, freq_wid=10 def bandpass_filter(recording, freq_min=300, freq_max=6000, 
freq_wid=1000, filter_type='fft', order=3, - chunk_size=30000, cache_to_file=False, cache_chunks=False): + chunk_size=30000, cache_to_file=False, cache_chunks=False, dtype=None): ''' Performs a lazy filter on the recording extractor traces. @@ -116,6 +116,8 @@ def bandpass_filter(recording, freq_min=300, freq_max=6000, freq_wid=1000, filte If True, filtered traces are computed and cached all at once on disk in temp file cache_chunks: bool (default False). If True then each chunk is cached in memory (in a dict) + dtype: dtype + The dtype of the traces Returns ------- @@ -133,7 +135,8 @@ def bandpass_filter(recording, freq_min=300, freq_max=6000, freq_wid=1000, filte filter_type=filter_type, order=order, chunk_size=chunk_size, - cache_chunks=cache_chunks + cache_chunks=cache_chunks, + dtype=dtype ) if cache_to_file: return se.CacheRecordingExtractor(bpf_recording, chunk_size=chunk_size) diff --git a/spiketoolkit/tests/test_postprocessing.py b/spiketoolkit/tests/test_postprocessing.py index c50e4345..39fe53bb 100644 --- a/spiketoolkit/tests/test_postprocessing.py +++ b/spiketoolkit/tests/test_postprocessing.py @@ -40,6 +40,14 @@ def test_waveforms(): assert np.allclose(w, w_gt) assert 'waveforms' not in sort.get_shared_unit_spike_feature_names() + # small chunks + wav = get_unit_waveforms(rec, sort, ms_before=ms_cut, ms_after=ms_cut, save_property_or_features=False, + n_jobs=n, memmap=m, chunk_mb=5, recompute_info=True) + + for (w, w_gt) in zip(wav, waveforms): + assert np.allclose(w, w_gt) + assert 'waveforms' not in sort.get_shared_unit_spike_feature_names() + # change cut ms wav = get_unit_waveforms(rec, sort, ms_before=2, ms_after=2, save_property_or_features=True, n_jobs=n, memmap=m, recompute_info=True) @@ -299,4 +307,4 @@ def test_compute_pca_scores(): if __name__ == '__main__': - test_waveforms() + test_export_to_phy()
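
Usage sketch (illustrative, not part of the patch): the snippet below exercises the chunked extraction path that this diff adds to get_unit_waveforms. The toy-dataset helper and the specific parameter values are assumptions chosen for illustration; chunk_mb, chunk_size, n_jobs and memmap are the arguments introduced or reused by the patch.

    import spikeextractors as se
    import spiketoolkit as st

    # Small synthetic recording/sorting pair (helper assumed available in spikeextractors)
    recording, sorting = se.example_datasets.toy_example(duration=10, num_channels=4, seed=0)

    # With chunk_mb=5 the recording is split into ~5 MB chunks:
    #   chunk_size = chunk_mb * 1e6 / (num_channels * bytes_per_sample)
    # and, when n_jobs > 1, each chunk is further divided by n_jobs so the
    # parallel workers operate on smaller pieces.
    waveforms = st.postprocessing.get_unit_waveforms(
        recording, sorting,
        ms_before=3, ms_after=3,
        chunk_mb=5,          # chunk size in MB (alternatively pass chunk_size in frames)
        n_jobs=2,            # chunks are processed in parallel with joblib
        memmap=True,         # waveforms are written to per-unit memmap files
        recompute_info=True,
    )
    print([wf.shape for wf in waveforms])  # one (n_spikes, n_channels, n_samples) array per unit

Design note: extracting waveforms chunk by chunk bounds memory use by chunk_mb regardless of recording length; each chunk is padded by roughly the snippet length so spikes falling near a chunk border are not truncated.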