Skip to content

Commit

Permalink
Merge pull request #1 from INM-6/total_spiking_probability_edges_pep8
Browse files Browse the repository at this point in the history
Total spiking probability edges , fix pep8 issues, refactor unittests
  • Loading branch information
zottelsheep authored Jul 24, 2023
2 parents 7eea5fe + 46709f0 commit 6a1e08d
Show file tree
Hide file tree
Showing 2 changed files with 170 additions and 137 deletions.
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
import itertools
from typing import Iterable, List, NamedTuple, Union, Optional, Tuple
from typing import Iterable, List, NamedTuple, Union, Optional

import numpy as np
from scipy.signal import oaconvolve

from elephant.conversion import BinnedSpikeTrain


def total_spiking_probability_edges(
spike_trains: BinnedSpikeTrain,
surrounding_window_sizes: Optional[List[int]] = None,
Expand All @@ -30,25 +31,37 @@ def total_spiking_probability_edges(
*Background:*
- On an excitatory connection the spikerate increases and decreases again due to the refractory period which results in local maxima in the cross-correlogram followed by downwards slope
- On an excitatory connection the spikerate increases and decreases again
due to the refractory period which results in local maxima in the
cross-correlogram followed by downwards slope
- On an inhibitory connection the spikerate decreases and after refractory period, increases again which results in lokal minima surrounded by high values in the cross-correlogram.
- On an inhibitory connection the spikerate decreases and after refractory
period, increases again which results in lokal minima surrounded by high
values in the cross-correlogram.
- An Edge-Filter can be used to interpret the cross-correlogram and accentuate the lokal Maxima and Minima
- An Edge-Filter can be used to interpret the cross-correlogram and
accentuate the lokal Maxima and Minima
*Procedure:*
1) Compute normalized cross-correlation :math:`NCC` of spiketrains of all Neuronpairs
2) Convolve :math:`NCC` with Edge-Filter :math:`g_{i}` to compute :math:`SPE`
3) Convolve :math:`SPE` with corresponding Running-Total-Filter :math:`h_{i}` to account for different lengths after convolution with Edge-Filter
4) Compute :math:`TSPE` using the sum of all :math:`SPE` for all different filterpairs
5) Compute connectivitymatrix by using the index of the tspe-values with the highest absolute values
1) Compute normalized cross-correlation :math:`NCC` of spiketrains of all
Neuronpairs
2) Convolve :math:`NCC` with Edge-Filter :math:`g_{i}` to compute
:math:`SPE`
3) Convolve :math:`SPE` with corresponding Running-Total-Filter
:math:`h_{i}` to account for different lengths after convolution with
Edge-Filter
4) Compute :math:`TSPE` using the sum of all :math:`SPE` for all different
filterpairs
5) Compute connectivitymatrix by using the index of the tspe-values with
the highest absolute values
*Normalized Cross-Correlation:*
.. math ::
NCC_{XY}(d) = \frac{1}{N} \sum_{i=-\infty}^{\infty}{ \frac{ (y_{(i)} - \bar{y}) \cdot (x_{(i-d)} - \bar{x}) }{ \sigma_x \cdot \sigma_y }}
NCC_{XY}(d) = \frac{1}{N} \sum_{i=-\infty}^{\infty}{ \frac{ (y_{(i)} -
\bar{y}) \cdot (x_{(i-d)} - \bar{x}) }{ \sigma_x \cdot \sigma_y }}
*Spiking Probability Edges*
Expand All @@ -58,7 +71,8 @@ def total_spiking_probability_edges(
*Total Spiking Probability Edges:*
.. math ::
TSPE_{X \rightarrow Y}(d) = \sum_{n=1}^{N_a \cdot N_b \cdot N_c}{SPE_{X \rightarrow Y}^{(n)}(d) * h(i)^{(n)} }
TSPE_{X \rightarrow Y}(d) = \sum_{n=1}^{N_a \cdot N_b \cdot N_c}
{SPE_{X \rightarrow Y}^{(n)}(d) * h(i)^{(n)} }
:cite:`functional_connectivity-de_blasi19_169`
Expand All @@ -67,11 +81,13 @@ def total_spiking_probability_edges(
spike_trains : (N, ) elephant.conversion.BinnedSpikeTrain
A binned spike train containing all neurons for connectivity estimation
surrounding_window_sizes : List[int], default = [3, 4, 5, 6, 7, 8]
Array of window-sizes for the surrounding area of the point of interest.
Array of window-sizes for the surrounding area of the point of
interest.
observed_window_sizes : List[int], default = [2, 3, 4, 5, 6]
Array of window-sizes for the observed area
crossover_window_sizes : List[int], default = [0]
Array of window-sizes for the crossover between surrounding and observed window.
Array of window-sizes for the crossover between surrounding and
observed window.
max_delay : int, default = 25
Defines the max delay when performing the normalized crosscorrelations.
Value depends on the bin-size of the BinnedSpikeTrain.
Expand Down Expand Up @@ -115,7 +131,8 @@ def total_spiking_probability_edges(
if normalize:
for delay_time in delay_times:
NCC_d[:, :, delay_time] /= np.sum(
NCC_d[:, :, delay_time][~np.identity(NCC_d.shape[0], dtype=bool)]
NCC_d[:, :, delay_time][~np.identity(NCC_d.shape[0],
dtype=bool)]
)

# Apply edge and running total filter
Expand All @@ -126,24 +143,29 @@ def total_spiking_probability_edges(
:,
:,
max_padding
- filter.needed_padding : max_delay
- filter.needed_padding: max_delay
+ max_padding
+ filter.needed_padding,
]

# Compute two convolutions with edge- and running total filter
x1 = oaconvolve(
NCC_window, np.expand_dims(filter.edge_filter, (0, 1)), mode="valid", axes=2
NCC_window, np.expand_dims(filter.edge_filter, (0, 1)),
mode="valid", axes=2
)
x2 = oaconvolve(
x1, np.expand_dims(filter.running_total_filter, (0, 1)), mode="full", axes=2
x1, np.expand_dims(filter.running_total_filter, (0, 1)),
mode="full", axes=2
)

tspe_matrix += x2

# Take maxima of absolute of delays to get estimation for connectivity
connectivity_matrix_index = np.argmax(np.abs(tspe_matrix),axis=2,keepdims=True)
connectivity_matrix = np.take_along_axis(tspe_matrix,connectivity_matrix_index,axis=2).squeeze(axis=2)
connectivity_matrix_index = np.argmax(np.abs(tspe_matrix),
axis=2, keepdims=True)
connectivity_matrix = np.take_along_axis(tspe_matrix,
connectivity_matrix_index, axis=2
).squeeze(axis=2)
delay_matrix = connectivity_matrix_index.squeeze()

return connectivity_matrix, delay_matrix
Expand All @@ -153,19 +175,21 @@ def normalized_cross_correlation(
spike_trains: BinnedSpikeTrain,
delay_times: Union[int, List[int], Iterable[int]] = 0,
) -> np.ndarray:
r"""normalized cross correlation using std deviation
r"""
Normalized cross correlation using std deviation
Computes the normalized_cross_correlation between all
Spiketrains inside a BinnedSpikeTrain-Object at a given delay_time
The underlying formula is:
.. math::
NCC_{X\arrY(d)} = \frac{1}{N_{bins}}\sum_{i=-\inf}^{\inf}{\frac{(y_{(i)} - \bar{y}) \cdot (x_{(i-d) - \bar{x})}{\sigma_x \cdot \sigma_y}}}
NCC_{X\arrY(d)} = \frac{1}{N_{bins}}\sum_{i=-\inf}^{\inf}{
\frac{(y_{(i)} - \bar{y}) \cdot (x_{(i-d) - \bar{x})}{\sigma_x
\cdot \sigma_y}}}
The subtraction of mean-values is omitted, since it offers little added
accuracy but increases the compute-time immensely.
"""

n_neurons, n_bins = spike_trains.shape
Expand All @@ -192,7 +216,8 @@ def normalized_cross_correlation(
# Uses theoretical zero-padding for shifted values,
# but since $0 \cdot x = 0$ values can simply be omitted
if delay_time == 0:
CC = spike_trains_array[:, :] @ spike_trains_array[:, :].transpose()
CC = spike_trains_array[:, :] @ spike_trains_array[:, :
].transpose()

elif delay_time > 0:
CC = (
Expand All @@ -206,14 +231,17 @@ def normalized_cross_correlation(
@ spike_trains_array[:, -delay_time:].transpose()
)

# Convert CC to dense matrix before performing the division
CC = CC.toarray()
# Normalize using std deviation
NCC = CC / std_factors / n_bins

# Compute cross correlation at given delay time
NCC_d[index, :, :] = NCC

# Move delay_time axis to back of array
# Makes index using neurons more intuitive → (n_neuron, n_neuron, delay_times)
# Makes index using neurons more intuitive → (n_neuron, n_neuron,
# delay_times)
NCC_d = np.moveaxis(NCC_d, 0, -1)

return NCC_d
Expand Down Expand Up @@ -250,7 +278,8 @@ def generate_edge_filter(
conditions = [
(i > 0) & (i <= surrounding_window_size),
(i > (surrounding_window_size + crossover_window_size))
& (i <= surrounding_window_size + observed_window_size + crossover_window_size),
& (i <= surrounding_window_size + observed_window_size +
crossover_window_size),
(
i
> surrounding_window_size
Expand Down
Loading

0 comments on commit 6a1e08d

Please sign in to comment.