diff --git a/elephant/functional_connectivity_src/total_spiking_probability_edges.py b/elephant/functional_connectivity_src/total_spiking_probability_edges.py index 1bc342212..b1964aa0b 100644 --- a/elephant/functional_connectivity_src/total_spiking_probability_edges.py +++ b/elephant/functional_connectivity_src/total_spiking_probability_edges.py @@ -1,11 +1,12 @@ import itertools -from typing import Iterable, List, NamedTuple, Union, Optional, Tuple +from typing import Iterable, List, NamedTuple, Union, Optional import numpy as np from scipy.signal import oaconvolve from elephant.conversion import BinnedSpikeTrain + def total_spiking_probability_edges( spike_trains: BinnedSpikeTrain, surrounding_window_sizes: Optional[List[int]] = None, @@ -30,25 +31,37 @@ def total_spiking_probability_edges( *Background:* - - On an excitatory connection the spikerate increases and decreases again due to the refractory period which results in local maxima in the cross-correlogram followed by downwards slope + - On an excitatory connection the spikerate increases and decreases again + due to the refractory period which results in local maxima in the + cross-correlogram followed by downwards slope - - On an inhibitory connection the spikerate decreases and after refractory period, increases again which results in lokal minima surrounded by high values in the cross-correlogram. + - On an inhibitory connection the spikerate decreases and after refractory + period, increases again which results in lokal minima surrounded by high + values in the cross-correlogram. - - An Edge-Filter can be used to interpret the cross-correlogram and accentuate the lokal Maxima and Minima + - An Edge-Filter can be used to interpret the cross-correlogram and + accentuate the lokal Maxima and Minima *Procedure:* - 1) Compute normalized cross-correlation :math:`NCC` of spiketrains of all Neuronpairs - 2) Convolve :math:`NCC` with Edge-Filter :math:`g_{i}` to compute :math:`SPE` - 3) Convolve :math:`SPE` with corresponding Running-Total-Filter :math:`h_{i}` to account for different lengths after convolution with Edge-Filter - 4) Compute :math:`TSPE` using the sum of all :math:`SPE` for all different filterpairs - 5) Compute connectivitymatrix by using the index of the tspe-values with the highest absolute values + 1) Compute normalized cross-correlation :math:`NCC` of spiketrains of all + Neuronpairs + 2) Convolve :math:`NCC` with Edge-Filter :math:`g_{i}` to compute + :math:`SPE` + 3) Convolve :math:`SPE` with corresponding Running-Total-Filter + :math:`h_{i}` to account for different lengths after convolution with + Edge-Filter + 4) Compute :math:`TSPE` using the sum of all :math:`SPE` for all different + filterpairs + 5) Compute connectivitymatrix by using the index of the tspe-values with + the highest absolute values *Normalized Cross-Correlation:* .. math :: - NCC_{XY}(d) = \frac{1}{N} \sum_{i=-\infty}^{\infty}{ \frac{ (y_{(i)} - \bar{y}) \cdot (x_{(i-d)} - \bar{x}) }{ \sigma_x \cdot \sigma_y }} + NCC_{XY}(d) = \frac{1}{N} \sum_{i=-\infty}^{\infty}{ \frac{ (y_{(i)} - + \bar{y}) \cdot (x_{(i-d)} - \bar{x}) }{ \sigma_x \cdot \sigma_y }} *Spiking Probability Edges* @@ -58,7 +71,8 @@ def total_spiking_probability_edges( *Total Spiking Probability Edges:* .. 
math :: - TSPE_{X \rightarrow Y}(d) = \sum_{n=1}^{N_a \cdot N_b \cdot N_c}{SPE_{X \rightarrow Y}^{(n)}(d) * h(i)^{(n)} } + TSPE_{X \rightarrow Y}(d) = \sum_{n=1}^{N_a \cdot N_b \cdot N_c} + {SPE_{X \rightarrow Y}^{(n)}(d) * h(i)^{(n)} } :cite:`functional_connectivity-de_blasi19_169` @@ -67,11 +81,13 @@ def total_spiking_probability_edges( spike_trains : (N, ) elephant.conversion.BinnedSpikeTrain A binned spike train containing all neurons for connectivity estimation surrounding_window_sizes : List[int], default = [3, 4, 5, 6, 7, 8] - Array of window-sizes for the surrounding area of the point of interest. + Array of window-sizes for the surrounding area of the point of + interest. observed_window_sizes : List[int], default = [2, 3, 4, 5, 6] Array of window-sizes for the observed area crossover_window_sizes : List[int], default = [0] - Array of window-sizes for the crossover between surrounding and observed window. + Array of window-sizes for the crossover between surrounding and + observed window. max_delay : int, default = 25 Defines the max delay when performing the normalized crosscorrelations. Value depends on the bin-size of the BinnedSpikeTrain. @@ -115,7 +131,8 @@ def total_spiking_probability_edges( if normalize: for delay_time in delay_times: NCC_d[:, :, delay_time] /= np.sum( - NCC_d[:, :, delay_time][~np.identity(NCC_d.shape[0], dtype=bool)] + NCC_d[:, :, delay_time][~np.identity(NCC_d.shape[0], + dtype=bool)] ) # Apply edge and running total filter @@ -126,24 +143,29 @@ def total_spiking_probability_edges( :, :, max_padding - - filter.needed_padding : max_delay + - filter.needed_padding: max_delay + max_padding + filter.needed_padding, ] # Compute two convolutions with edge- and running total filter x1 = oaconvolve( - NCC_window, np.expand_dims(filter.edge_filter, (0, 1)), mode="valid", axes=2 + NCC_window, np.expand_dims(filter.edge_filter, (0, 1)), + mode="valid", axes=2 ) x2 = oaconvolve( - x1, np.expand_dims(filter.running_total_filter, (0, 1)), mode="full", axes=2 + x1, np.expand_dims(filter.running_total_filter, (0, 1)), + mode="full", axes=2 ) tspe_matrix += x2 # Take maxima of absolute of delays to get estimation for connectivity - connectivity_matrix_index = np.argmax(np.abs(tspe_matrix),axis=2,keepdims=True) - connectivity_matrix = np.take_along_axis(tspe_matrix,connectivity_matrix_index,axis=2).squeeze(axis=2) + connectivity_matrix_index = np.argmax(np.abs(tspe_matrix), + axis=2, keepdims=True) + connectivity_matrix = np.take_along_axis(tspe_matrix, + connectivity_matrix_index, axis=2 + ).squeeze(axis=2) delay_matrix = connectivity_matrix_index.squeeze() return connectivity_matrix, delay_matrix @@ -153,7 +175,8 @@ def normalized_cross_correlation( spike_trains: BinnedSpikeTrain, delay_times: Union[int, List[int], Iterable[int]] = 0, ) -> np.ndarray: - r"""normalized cross correlation using std deviation + r""" + Normalized cross correlation using std deviation Computes the normalized_cross_correlation between all Spiketrains inside a BinnedSpikeTrain-Object at a given delay_time @@ -161,11 +184,12 @@ def normalized_cross_correlation( The underlying formula is: .. 
math:: - NCC_{X\arrY(d)} = \frac{1}{N_{bins}}\sum_{i=-\inf}^{\inf}{\frac{(y_{(i)} - \bar{y}) \cdot (x_{(i-d) - \bar{x})}{\sigma_x \cdot \sigma_y}}} + NCC_{X\arrY(d)} = \frac{1}{N_{bins}}\sum_{i=-\inf}^{\inf}{ + \frac{(y_{(i)} - \bar{y}) \cdot (x_{(i-d) - \bar{x})}{\sigma_x + \cdot \sigma_y}}} The subtraction of mean-values is omitted, since it offers little added accuracy but increases the compute-time immensely. - """ n_neurons, n_bins = spike_trains.shape @@ -192,7 +216,8 @@ def normalized_cross_correlation( # Uses theoretical zero-padding for shifted values, # but since $0 \cdot x = 0$ values can simply be omitted if delay_time == 0: - CC = spike_trains_array[:, :] @ spike_trains_array[:, :].transpose() + CC = spike_trains_array[:, :] @ spike_trains_array[:, : + ].transpose() elif delay_time > 0: CC = ( @@ -206,6 +231,8 @@ def normalized_cross_correlation( @ spike_trains_array[:, -delay_time:].transpose() ) + # Convert CC to dense matrix before performing the division + CC = CC.toarray() # Normalize using std deviation NCC = CC / std_factors / n_bins @@ -213,7 +240,8 @@ def normalized_cross_correlation( NCC_d[index, :, :] = NCC # Move delay_time axis to back of array - # Makes index using neurons more intuitive → (n_neuron, n_neuron, delay_times) + # Makes index using neurons more intuitive → (n_neuron, n_neuron, + # delay_times) NCC_d = np.moveaxis(NCC_d, 0, -1) return NCC_d @@ -250,7 +278,8 @@ def generate_edge_filter( conditions = [ (i > 0) & (i <= surrounding_window_size), (i > (surrounding_window_size + crossover_window_size)) - & (i <= surrounding_window_size + observed_window_size + crossover_window_size), + & (i <= surrounding_window_size + observed_window_size + + crossover_window_size), ( i > surrounding_window_size diff --git a/elephant/test/test_total_spiking_probability_edges.py b/elephant/test/test_total_spiking_probability_edges.py index 54ef630f6..60bfa5e2f 100644 --- a/elephant/test/test_total_spiking_probability_edges.py +++ b/elephant/test/test_total_spiking_probability_edges.py @@ -1,88 +1,116 @@ +import unittest from pathlib import Path -from typing import List, Optional, Tuple, Union +from typing import Tuple, Union from neo import SpikeTrain import numpy as np -import pytest from quantities import millisecond as ms from scipy.io import loadmat from elephant.conversion import BinnedSpikeTrain -from elephant.functional_connectivity_src.total_spiking_probability_edges import ( - generate_filter_pairs, - normalized_cross_correlation, - TspeFilterPair, - total_spiking_probability_edges, -) +from elephant.functional_connectivity_src.total_spiking_probability_edges \ + import (generate_filter_pairs, + normalized_cross_correlation, + TspeFilterPair, + total_spiking_probability_edges, + ) from elephant.datasets import download_datasets -def test_generate_filter_pairs(): - a = [1] - b = [1] - c = [1] - test_output = [ - TspeFilterPair( - edge_filter=np.array([-1.0, 0.0, 2.0, 0.0, -1.0]), - running_total_filter=np.array([1.0]), - needed_padding=2, - surrounding_window_size=1, - observed_window_size=1, - crossover_window_size=1, - ) - ] +class TotalSpikingProbabilityEdgesTestCase(unittest.TestCase): + def test_generate_filter_pairs(self): + a = [1] + b = [1] + c = [1] + test_output = [ + TspeFilterPair( + edge_filter=np.array([-1.0, 0.0, 2.0, 0.0, -1.0]), + running_total_filter=np.array([1.0]), + needed_padding=2, + surrounding_window_size=1, + observed_window_size=1, + crossover_window_size=1, + ) + ] - function_output = generate_filter_pairs(a, b, c) + function_output = 
generate_filter_pairs(a, b, c) - for filter_pair_function, filter_pair_test in zip(function_output, test_output): - assert np.array_equal( - filter_pair_function.edge_filter, filter_pair_test.edge_filter - ) - assert np.array_equal( - filter_pair_function.running_total_filter, - filter_pair_test.running_total_filter, - ) - assert filter_pair_function.needed_padding == filter_pair_test.needed_padding - assert ( - filter_pair_function.surrounding_window_size - == filter_pair_test.surrounding_window_size - ) - assert ( - filter_pair_function.observed_window_size - == filter_pair_test.observed_window_size + for filter_pair_function, filter_pair_test in zip(function_output, + test_output): + np.testing.assert_array_equal( + filter_pair_function.edge_filter, + filter_pair_test.edge_filter) + + np.testing.assert_array_equal( + filter_pair_function.running_total_filter, + filter_pair_test.running_total_filter) + + self.assertEqual(filter_pair_function.needed_padding, + filter_pair_test.needed_padding) + + self.assertEqual(filter_pair_function.surrounding_window_size, + filter_pair_test.surrounding_window_size) + + self.assertEqual(filter_pair_function.observed_window_size, + filter_pair_test.observed_window_size) + + self.assertEqual(filter_pair_function.crossover_window_size, + filter_pair_test.crossover_window_size) + + def test_normalized_cross_correlation(self): + # Generate Spiketrains + delay_time = 5 + spike_times = [3, 4, 5] * ms + spike_times_delayed = spike_times + delay_time * ms + + spiketrains = BinnedSpikeTrain( + [SpikeTrain(spike_times, t_stop=20.0 * ms), + SpikeTrain(spike_times_delayed, t_stop=20.0 * ms),], + bin_size=1 * ms, ) - assert ( - filter_pair_function.crossover_window_size - == filter_pair_test.crossover_window_size + + test_output = np.array([[[0.0, 0.0], [1.1, 0.0]], [[0.0, 1.1], + [0.0, 0.0]]]) + + function_output = normalized_cross_correlation( + spiketrains, [-delay_time, delay_time] ) + assert np.allclose(function_output, test_output, 0.1) + + def test_total_spiking_probability_edges(self): + files = ["SW/new_sim0_100.mat", + "BA/new_sim0_100.mat", + "CA/new_sim0_100.mat", + "ER05/new_sim0_100.mat", + "ER10/new_sim0_100.mat", + "ER15/new_sim0_100.mat", + ] -def test_normalized_cross_correlation(): - # Generate Spiketrains - delay_time = 5 - spike_times = [3, 4, 5] * ms - spike_times_delayed = spike_times + delay_time * ms + for datafile in files: + repo_base_path = 'unittest/functional_connectivity/' \ + 'total_spiking_probability_edges/data/' + downloaded_dataset_path = download_datasets(repo_base_path + + datafile) - spiketrains = BinnedSpikeTrain( - [ - SpikeTrain(spike_times, t_stop=20.0 * ms), - SpikeTrain(spike_times_delayed, t_stop=20.0 * ms), - ], - bin_size=1 * ms, - ) + spiketrains, original_data = load_spike_train_simulated( + downloaded_dataset_path) - test_output = np.array([[[0.0, 0.0], [1.1, 0.0]], [[0.0, 1.1], [0.0, 0.0]]]) + connectivity_matrix, delay_matrix = \ + total_spiking_probability_edges(spiketrains) - function_output = normalized_cross_correlation( - spiketrains, [-delay_time, delay_time] - ) + # Remove self-connections + np.fill_diagonal(connectivity_matrix, 0) - assert np.allclose(function_output, test_output, 0.1) + _, _, _, auc = roc_curve(connectivity_matrix, original_data) + + self.assertGreater(auc, 0.95) # ====== HELPER FUNCTIONS ====== -def classify_connections(connectivity_matrix:np.ndarray,threshold:int): + +def classify_connections(connectivity_matrix: np.ndarray, threshold: int): connectivity_matrix_binarized = 
connectivity_matrix.copy() mask_excitatory = connectivity_matrix_binarized > threshold @@ -96,6 +124,7 @@ def classify_connections(connectivity_matrix:np.ndarray,threshold:int): return connectivity_matrix_binarized + def confusion_matrix(estimate, original, threshold: int = 1): """ Definition: @@ -104,39 +133,42 @@ def confusion_matrix(estimate, original, threshold: int = 1): - TN: Matches for non-existing synapses are True Negative - FN: mismatches are False Negative. """ - if not np.all(np.isin([-1,0,1], np.unique(estimate))): - estimate = classify_connections(estimate,threshold) - if not np.all(np.isin([-1,0,1], np.unique(original))): - original = classify_connections(original,threshold) + if not np.all(np.isin([-1, 0, 1], np.unique(estimate))): + estimate = classify_connections(estimate, threshold) + if not np.all(np.isin([-1, 0, 1], np.unique(original))): + original = classify_connections(original, threshold) - TP = (np.not_equal(estimate,0) & np.not_equal(original,0)).sum() + TP = (np.not_equal(estimate, 0) & np.not_equal(original, 0)).sum() - TN = (np.equal(estimate,0) & np.equal(original, 0)).sum() + TN = (np.equal(estimate, 0) & np.equal(original, 0)).sum() - FP = (np.not_equal(estimate,0) & np.equal(original, 0)).sum() + FP = (np.not_equal(estimate, 0) & np.equal(original, 0)).sum() - FN = (np.equal(estimate, 0) & np.not_equal(original,0)).sum() + FN = (np.equal(estimate, 0) & np.not_equal(original, 0)).sum() return TP, TN, FP, FN -def fall_out(TP:int, TN:int, FP:int, FN:int): + +def fall_out(TP: int, TN: int, FP: int, FN: int): FPR = FP / (FP + TN) return FPR -def sensitivity(TP:int, TN:int, FP:int, FN:int): + +def sensitivity(TP: int, TN: int, FP: int, FN: int): TPR = TP / (TP + FN) return TPR -def roc_curve(estimate,original): + +def roc_curve(estimate, original): tpr_list = [] fpr_list = [] - max_threshold = max(np.max(np.abs(estimate)),1) + max_threshold = max(np.max(np.abs(estimate)), 1) - thresholds = np.linspace(max_threshold,0,30) + thresholds = np.linspace(max_threshold, 0, 30) for t in thresholds: - conf_matrix = confusion_matrix(estimate,original,threshold=t) + conf_matrix = confusion_matrix(estimate, original, threshold=t) tpr_list.append(sensitivity(*conf_matrix)) fpr_list.append(fall_out(*conf_matrix)) @@ -146,11 +178,9 @@ def roc_curve(estimate,original): return tpr_list, fpr_list, thresholds, auc -def load_spike_train_simulated( - path: Union[Path, str], - bin_size = None, - t_stop = None, -) -> Tuple[BinnedSpikeTrain, np.ndarray]: +def load_spike_train_simulated(path: Union[Path, str], bin_size=None, + t_stop=None, + ) -> Tuple[BinnedSpikeTrain, np.ndarray]: if isinstance(path, str): path = Path(path) @@ -160,11 +190,13 @@ def load_spike_train_simulated( data = loadmat(path, simplify_cells=True)["data"] if "asdf" not in data: - raise ValueError('Incorrect Dataformat: Missing spiketrain_data in "asdf"') + raise ValueError('Incorrect Dataformat: Missing spiketrain_data in' + '"asdf"') spiketrain_data = data["asdf"] - # Get number of electrodesa and recording_duration from last element of data array + # Get number of electrodesa and recording_duration from last element of + # data array n_electrodes, recording_duration_ms = spiketrain_data[-1] recording_duration_ms = recording_duration_ms * ms @@ -174,42 +206,14 @@ def load_spike_train_simulated( spiketrains.append( SpikeTrain( spiketrain_raw * ms, - t_stop= recording_duration_ms, + t_stop=recording_duration_ms, ) ) - spiketrains = BinnedSpikeTrain(spiketrains, bin_size=bin_size, t_stop = t_stop or 
recording_duration_ms) + spiketrains = BinnedSpikeTrain(spiketrains, bin_size=bin_size, + t_stop=t_stop or recording_duration_ms) # Load original_data original_data = data['SWM'].T return spiketrains, original_data - - -@pytest.mark.parametrize( - "datafile", - [ - "SW/new_sim0_100.mat", - "BA/new_sim0_100.mat", - "CA/new_sim0_100.mat", - "ER05/new_sim0_100.mat", - "ER10/new_sim0_100.mat", - "ER15/new_sim0_100.mat", - ], -) -def test_total_spiking_probability_edges(datafile): - - repo_base_path = 'unittest/functional_connectivity/total_spiking_probability_edges/data/' - downloaded_dataset_path = download_datasets(repo_base_path+datafile) - - spiketrains, original_data = load_spike_train_simulated(downloaded_dataset_path) - - connectivity_matrix, delay_matrix = total_spiking_probability_edges(spiketrains) - - # Remove self-connections - np.fill_diagonal(connectivity_matrix, 0) - - _, _, _, auc = roc_curve(connectivity_matrix, original_data) - - assert auc > 0.95 -
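
Reviewer note, not part of the patch: below is a minimal, illustrative sketch of how the refactored estimator is called end to end, using two synthetic spike trains in which the second neuron fires a fixed delay after the first. The spike times, delay, bin size and t_stop are arbitrary example values chosen for this sketch; the import path, the default parameters and the (connectivity_matrix, delay_matrix) return pair are taken from the module as modified above.

    # Illustrative sketch only: exercise total_spiking_probability_edges on two
    # synthetic neurons where neuron 1 follows neuron 0 with a fixed delay.
    import numpy as np
    from neo import SpikeTrain
    from quantities import millisecond as ms

    from elephant.conversion import BinnedSpikeTrain
    from elephant.functional_connectivity_src.total_spiking_probability_edges \
        import total_spiking_probability_edges

    delay = 5  # delay in bins, assuming the 1 ms bin size used below
    spike_times = [10, 16, 23, 40, 52] * ms
    spiketrains = BinnedSpikeTrain(
        [SpikeTrain(spike_times, t_stop=100 * ms),
         SpikeTrain(spike_times + delay * ms, t_stop=100 * ms)],
        bin_size=1 * ms,
    )

    # connectivity_matrix holds the TSPE value with the largest absolute value
    # per neuron pair; delay_matrix holds the delay (in bins) at which that
    # extremum occurs.
    connectivity_matrix, delay_matrix = total_spiking_probability_edges(
        spiketrains)

    # Self-connections are not meaningful for this estimator; the unit test
    # above zeroes the diagonal the same way before computing the ROC curve.
    np.fill_diagonal(connectivity_matrix, 0)
    print(connectivity_matrix)
    print(delay_matrix)

With the default window sizes this mirrors the procedure documented in the docstring (normalized cross-correlation, edge filter, running-total filter, maximum of the summed SPE), and the off-diagonal entries should peak near the imposed delay.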