Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Phase-slope index using spectral_connectivity_time instead of spectral_connectivity_epochs #210

Open
wants to merge 20 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ on numpy array inputs.

envelope_correlation
phase_slope_index
phase_slope_index_time
vector_auto_regression
spectral_connectivity_epochs
spectral_connectivity_time
Expand Down
Empty file added file_paths.txt
Empty file.
2 changes: 1 addition & 1 deletion mne_connectivity/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
)
from .datasets import make_signals_in_freq_bands
from .decoding import CoherencyDecomposition
from .effective import phase_slope_index
from .effective import phase_slope_index, phase_slope_index_time
from .envelope import envelope_correlation, symmetric_orth
from .io import read_connectivity
from .spectral import spectral_connectivity_epochs, spectral_connectivity_time
Expand Down
197 changes: 195 additions & 2 deletions mne_connectivity/effective.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,12 @@
import numpy as np
from mne.utils import logger, verbose

from .base import SpectralConnectivity, SpectroTemporalConnectivity
from .spectral import spectral_connectivity_epochs
from .base import (
EpochSpectralConnectivity,
SpectralConnectivity,
SpectroTemporalConnectivity,
)
seqasim marked this conversation as resolved.
Show resolved Hide resolved
from .spectral import spectral_connectivity_epochs, spectral_connectivity_time
from .utils import fill_doc


Expand Down Expand Up @@ -240,3 +244,192 @@ def phase_slope_index(
)

return conn


@verbose
@fill_doc
def phase_slope_index_time(
data,
freqs,
indices=None,
sfreq=2 * np.pi,
mode="cwt_morlet",
fmin=None,
fmax=None,
padding=0,
mt_bandwidth=None,
n_cycles=7,
n_jobs=1,
verbose=None,
):
"""Compute the Phase Slope Index (PSI) connectivity measure across time.

This function computes PSI over time from epoched data. The data may consist of a
single epoch.

The PSI is an effective connectivity measure, i.e., a measure which can give an
indication of the direction of the information flow (causality). For two time
series, one computes the PSI between the first and the second time series as
follows: ::

indices = (np.array([0]), np.array([1]))
psi = phase_slope_index(data, indices=indices, ...)

A positive value means that time series 0 is ahead of time series 1 and a negative
value means the opposite.

The PSI is computed from the coherency (see :func:`spectral_connectivity_time`),
details can be found in :footcite:`NolteEtAl2008`.

Parameters
----------
data : array-like, shape (n_epochs, n_signals, n_times) | Epochs
The data from which to compute connectivity.
freqs : array-like
Array of frequencies of interest for time-frequency decomposition. Only the
frequencies within the range specified by ``fmin`` and ``fmax`` are used.
indices : tuple of array | None
Two arrays with indices of connections for which to compute connectivity. If
`None`, all connections are computed.
sfreq : float
The sampling frequency. Required if data is not :class:`~mne.Epochs`.
mode : str
Time-frequency decomposition method. Can be either: 'multitaper' or
'cwt_morlet'. See :func:`mne.time_frequency.tfr_array_multitaper` and
:func:`mne.time_frequency.tfr_array_morlet` for reference.
fmin : float | tuple of float | None
The lower frequency of interest. Multiple bands are defined using a tuple, e.g.,
``(8., 20.)`` for two bands with 8 Hz and 20 Hz lower bounds. If `None`, the
lowest frequency in ``freqs`` is used.
fmax : float | tuple of float | None
The upper frequency of interest. Multiple bands are defined using a tuple, e.g.
``(13., 30.)`` for two band with 13 Hz and 30 Hz upper bounds. If `None`, the
highest frequency in ``freqs`` is used.
padding : float
Amount of time to consider as padding at the beginning and end of each epoch in
seconds. See Notes of :func:`spectral_connectivity_time` for more information.
mt_bandwidth : float | None
The bandwidth of the multitaper windowing function in Hz. Only used if
``mode='multitaper'``.
n_cycles : float | array-like of float
Number of cycles. Fixed number or one per frequency. Only used if
``mode='cwt_morlet'``.
n_jobs : int
Number of connections to compute in parallel. Memory mapping must be activated.
Please see the Notes section of :func:`spectral_connectivity_time` for details.
%(verbose)s

Returns
-------
conn : instance of EpochSpectralConnectivity
Computed connectivity measure. An instance of
:class:`EpochSpectralConnectivity`. The shape of the connectivity dataset is
``(n_epochs, n_cons, n_bands)``. When ``indices`` is `None`,
``n_cons = n_signals ** 2``. When ``indices`` is specified,
``n_cons = len(indices[0])``.

See Also
--------
mne_connectivity.EpochSpectralConnectivity
tsbinns marked this conversation as resolved.
Show resolved Hide resolved
mne_connectivity.spectral_connectivity_time

References
----------
.. footbibliography::
"""
logger.info("Estimating phase slope index (PSI) across time")

# estimate the coherency
cohy = spectral_connectivity_time(
data,
freqs=freqs,
method="cohy",
average=False,
indices=indices,
sfreq=sfreq,
fmin=fmin,
fmax=fmax,
fskip=0,
faverage=False,
sm_times=0,
sm_freqs=1,
sm_kernel="hanning",
padding=padding,
Comment on lines +354 to +357
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wonder if we should expose more of these options to the user, e.g. why just padding and not the smoothing stuff?

mode=mode,
mt_bandwidth=mt_bandwidth,
n_cycles=n_cycles,
decim=1,
n_jobs=n_jobs,
verbose=verbose,
)

freqs_ = np.array(cohy.freqs)
names = cohy.names
n_tapers = cohy.attrs.get("n_tapers")
n_nodes = cohy.n_nodes
metadata = cohy.metadata
events = cohy.events
event_id = cohy.event_id

logger.info(f"Computing PSI from estimated Coherency: {cohy}")
# compute PSI in the requested bands
if fmin is None:
fmin = -np.inf
if fmax is None:
fmax = np.inf

bands = list(zip(np.asarray((fmin,)).ravel(), np.asarray((fmax,)).ravel()))
n_bands = len(bands)

freq_dim = -1

# allocate space for output
out_shape = list(cohy.shape)
out_shape[freq_dim] = n_bands
psi = np.zeros(out_shape, dtype=np.float64)

# allocate accumulator
acc_shape = copy.copy(out_shape)
acc_shape.pop(freq_dim)
acc = np.empty(acc_shape, dtype=np.complex128)

# create list for frequencies used and frequency bands
# of resulting connectivity data
freqs = list()
freq_bands = list()
idx_fi = [slice(None)] * len(out_shape)
idx_fj = [slice(None)] * len(out_shape)
for band_idx, band in enumerate(bands):
freq_idx = np.where((freqs_ > band[0]) & (freqs_ < band[1]))[0]
freqs.append(freqs_[freq_idx])
freq_bands.append(np.mean(freqs_[freq_idx]))

acc.fill(0.0)
for fi, fj in zip(freq_idx, freq_idx[1:]):
idx_fi[freq_dim] = fi
idx_fj[freq_dim] = fj
acc += (
np.conj(cohy.get_data()[tuple(idx_fi)]) * cohy.get_data()[tuple(idx_fj)]
)

idx_fi[freq_dim] = band_idx
psi[tuple(idx_fi)] = np.imag(acc)
logger.info("[PSI Estimation Done]")

# create a connectivity container
conn = EpochSpectralConnectivity(
data=psi,
names=names,
freqs=freq_bands,
n_nodes=n_nodes,
method="phase-slope-index",
spec_method=mode,
indices=indices,
freqs_computed=freqs,
n_tapers=n_tapers,
metadata=metadata,
events=events,
event_id=event_id,
)

return conn
43 changes: 38 additions & 5 deletions mne_connectivity/spectral/time.py
Original file line number Diff line number Diff line change
Expand Up @@ -577,8 +577,8 @@ def spectral_connectivity_time(
conn = dict()
conn_patterns = dict()
for m in method:
# CaCoh complex-valued, all other methods real-valued
if m == "cacoh":
# Cohy and CaCoh complex-valued, all other methods real-valued
if m in ["cacoh", "cohy"]:
con_scores_dtype = np.complex128
else:
con_scores_dtype = np.float64
Expand Down Expand Up @@ -943,7 +943,7 @@ def _parallel_con(
methods are called, the output is a tuple of lists containing arrays
for the connectivity scores and patterns, respectively.
"""
if "coh" in method:
if ("coh" in method) or ("cohy" in method):
# psd
if weights is not None:
psd = weights * w
Expand Down Expand Up @@ -1035,9 +1035,16 @@ def _pairwise_con(w, psd, x, y, method, kernel, foi_idx, faverage, weights):
s_xy = np.squeeze(s_xy, axis=0)
s_xy = _smooth_spectra(s_xy, kernel)
out = []
conn_func = {"plv": _plv, "ciplv": _ciplv, "pli": _pli, "wpli": _wpli, "coh": _coh}
conn_func = {
"plv": _plv,
"ciplv": _ciplv,
"pli": _pli,
"wpli": _wpli,
"coh": _coh,
"cohy": _cohy,
}
for m in method:
if m == "coh":
if m in ["coh", "cohy"]:
s_xx = psd[x]
s_yy = psd[y]
out.append(conn_func[m](s_xx, s_yy, s_xy))
Expand Down Expand Up @@ -1282,6 +1289,32 @@ def _coh(s_xx, s_yy, s_xy):
return coh


def _cohy(s_xx, s_yy, s_xy):
"""Compute coherencey given the cross spectral density and PSD.

Parameters
----------
s_xx : array-like, shape (n_freqs, n_times)
The PSD of channel 'x'.
s_yy : array-like, shape (n_freqs, n_times)
The PSD of channel 'y'.
s_xy : array-like, shape (n_freqs, n_times)
The cross PSD between channel 'x' and channel 'y' across
frequency and time points.

Returns
-------
cohy : array-like, shape (n_freqs, n_times)
The estimated COHY.
"""
con_num = s_xy.mean(axis=-1, keepdims=True)
con_den = np.sqrt(
s_xx.mean(axis=-1, keepdims=True) * s_yy.mean(axis=-1, keepdims=True)
)
cohy = con_num / con_den
return cohy
seqasim marked this conversation as resolved.
Show resolved Hide resolved


def _compute_csd(x, y, weights):
"""Compute cross spectral density between signals x and y."""
if weights is not None:
Expand Down
39 changes: 38 additions & 1 deletion mne_connectivity/tests/test_effective.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import numpy as np
from numpy.testing import assert_array_almost_equal

from mne_connectivity.effective import phase_slope_index
from mne_connectivity.effective import phase_slope_index, phase_slope_index_time


def test_psi():
Expand Down Expand Up @@ -39,3 +39,40 @@ def test_psi():

assert np.all(conn_cwt.get_data() > 0)
assert conn_cwt.shape[-1] == n_times


def test_psi_time():
"""Test Phase Slope Index (PSI) estimation across time."""
tsbinns marked this conversation as resolved.
Show resolved Hide resolved
sfreq = 50.0
n_signals = 3
n_epochs = 10
n_times = 500
rng = np.random.RandomState(42)
data = rng.randn(n_epochs, n_signals, n_times)

# simulate time shifts
for i in range(n_epochs):
data[i, 1, 10:] = data[i, 0, :-10] # signal 0 is ahead
data[i, 2, :-10] = data[i, 0, 10:] # signal 2 is ahead

# conn = phase_slope_index_time(data, mode="fourier", sfreq=sfreq, freqs=np.arange(4))

# assert conn.get_data(output="dense")[1, 0, 0] < 0
# assert conn.get_data(output="dense")[2, 0, 0] > 0

# # only compute for a subset of the indices
indices = (np.array([0]), np.array([1]))
# conn_2 = phase_slope_index_time(data, mode="fourier", sfreq=sfreq, freqs=np.arange(4), indices=indices)

# # the measure is symmetric (sign flip)
# assert_array_almost_equal(
# conn_2.get_data()[0, 0], -conn.get_data(output="dense")[1, 0, 0]
# )

freqs = np.arange(5.0, 20, 0.5)
conn_cwt = phase_slope_index_time(
data, mode="cwt_morlet", sfreq=sfreq, freqs=freqs, indices=indices
)
Comment on lines +58 to +75
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Including some of the things commented here would be nice to better reflect the checks happening for the regular PSI function.

So as well as checking that seed -> target connectivity is > 0, should also check that target -> seed connectivity is < 0 and they are identical but just sign-flipped.


assert np.all(conn_cwt.get_data() > 0)
assert conn_cwt.shape[0] == n_epochs