diff --git a/doc/api.rst b/doc/api.rst index 81c84418..a43cfdfc 100644 --- a/doc/api.rst +++ b/doc/api.rst @@ -45,6 +45,7 @@ on numpy array inputs. envelope_correlation phase_slope_index + phase_slope_index_time vector_auto_regression spectral_connectivity_epochs spectral_connectivity_time diff --git a/file_paths.txt b/file_paths.txt new file mode 100644 index 00000000..e69de29b diff --git a/mne_connectivity/__init__.py b/mne_connectivity/__init__.py index ce18a284..25c37490 100644 --- a/mne_connectivity/__init__.py +++ b/mne_connectivity/__init__.py @@ -25,7 +25,7 @@ ) from .datasets import make_signals_in_freq_bands from .decoding import CoherencyDecomposition -from .effective import phase_slope_index +from .effective import phase_slope_index, phase_slope_index_time from .envelope import envelope_correlation, symmetric_orth from .io import read_connectivity from .spectral import spectral_connectivity_epochs, spectral_connectivity_time diff --git a/mne_connectivity/effective.py b/mne_connectivity/effective.py index ab598418..307e62ce 100644 --- a/mne_connectivity/effective.py +++ b/mne_connectivity/effective.py @@ -7,8 +7,12 @@ import numpy as np from mne.utils import logger, verbose -from .base import SpectralConnectivity, SpectroTemporalConnectivity -from .spectral import spectral_connectivity_epochs +from .base import ( + EpochSpectralConnectivity, + SpectralConnectivity, + SpectroTemporalConnectivity, +) +from .spectral import spectral_connectivity_epochs, spectral_connectivity_time from .utils import fill_doc @@ -240,3 +244,192 @@ def phase_slope_index( ) return conn + + +@verbose +@fill_doc +def phase_slope_index_time( + data, + freqs, + indices=None, + sfreq=2 * np.pi, + mode="cwt_morlet", + fmin=None, + fmax=None, + padding=0, + mt_bandwidth=None, + n_cycles=7, + n_jobs=1, + verbose=None, +): + """Compute the Phase Slope Index (PSI) connectivity measure across time. + + This function computes PSI over time from epoched data. The data may consist of a + single epoch. + + The PSI is an effective connectivity measure, i.e., a measure which can give an + indication of the direction of the information flow (causality). For two time + series, one computes the PSI between the first and the second time series as + follows: :: + + indices = (np.array([0]), np.array([1])) + psi = phase_slope_index(data, indices=indices, ...) + + A positive value means that time series 0 is ahead of time series 1 and a negative + value means the opposite. + + The PSI is computed from the coherency (see :func:`spectral_connectivity_time`), + details can be found in :footcite:`NolteEtAl2008`. + + Parameters + ---------- + data : array-like, shape (n_epochs, n_signals, n_times) | Epochs + The data from which to compute connectivity. + freqs : array-like + Array of frequencies of interest for time-frequency decomposition. Only the + frequencies within the range specified by ``fmin`` and ``fmax`` are used. + indices : tuple of array | None + Two arrays with indices of connections for which to compute connectivity. If + `None`, all connections are computed. + sfreq : float + The sampling frequency. Required if data is not :class:`~mne.Epochs`. + mode : str + Time-frequency decomposition method. Can be either: 'multitaper' or + 'cwt_morlet'. See :func:`mne.time_frequency.tfr_array_multitaper` and + :func:`mne.time_frequency.tfr_array_morlet` for reference. + fmin : float | tuple of float | None + The lower frequency of interest. Multiple bands are defined using a tuple, e.g., + ``(8., 20.)`` for two bands with 8 Hz and 20 Hz lower bounds. If `None`, the + lowest frequency in ``freqs`` is used. + fmax : float | tuple of float | None + The upper frequency of interest. Multiple bands are defined using a tuple, e.g. + ``(13., 30.)`` for two band with 13 Hz and 30 Hz upper bounds. If `None`, the + highest frequency in ``freqs`` is used. + padding : float + Amount of time to consider as padding at the beginning and end of each epoch in + seconds. See Notes of :func:`spectral_connectivity_time` for more information. + mt_bandwidth : float | None + The bandwidth of the multitaper windowing function in Hz. Only used if + ``mode='multitaper'``. + n_cycles : float | array-like of float + Number of cycles. Fixed number or one per frequency. Only used if + ``mode='cwt_morlet'``. + n_jobs : int + Number of connections to compute in parallel. Memory mapping must be activated. + Please see the Notes section of :func:`spectral_connectivity_time` for details. + %(verbose)s + + Returns + ------- + conn : instance of EpochSpectralConnectivity + Computed connectivity measure. An instance of + :class:`EpochSpectralConnectivity`. The shape of the connectivity dataset is + ``(n_epochs, n_cons, n_bands)``. When ``indices`` is `None`, + ``n_cons = n_signals ** 2``. When ``indices`` is specified, + ``n_cons = len(indices[0])``. + + See Also + -------- + mne_connectivity.EpochSpectralConnectivity + mne_connectivity.spectral_connectivity_time + + References + ---------- + .. footbibliography:: + """ + logger.info("Estimating phase slope index (PSI) across time") + + # estimate the coherency + cohy = spectral_connectivity_time( + data, + freqs=freqs, + method="cohy", + average=False, + indices=indices, + sfreq=sfreq, + fmin=fmin, + fmax=fmax, + fskip=0, + faverage=False, + sm_times=0, + sm_freqs=1, + sm_kernel="hanning", + padding=padding, + mode=mode, + mt_bandwidth=mt_bandwidth, + n_cycles=n_cycles, + decim=1, + n_jobs=n_jobs, + verbose=verbose, + ) + + freqs_ = np.array(cohy.freqs) + names = cohy.names + n_tapers = cohy.attrs.get("n_tapers") + n_nodes = cohy.n_nodes + metadata = cohy.metadata + events = cohy.events + event_id = cohy.event_id + + logger.info(f"Computing PSI from estimated Coherency: {cohy}") + # compute PSI in the requested bands + if fmin is None: + fmin = -np.inf + if fmax is None: + fmax = np.inf + + bands = list(zip(np.asarray((fmin,)).ravel(), np.asarray((fmax,)).ravel())) + n_bands = len(bands) + + freq_dim = -1 + + # allocate space for output + out_shape = list(cohy.shape) + out_shape[freq_dim] = n_bands + psi = np.zeros(out_shape, dtype=np.float64) + + # allocate accumulator + acc_shape = copy.copy(out_shape) + acc_shape.pop(freq_dim) + acc = np.empty(acc_shape, dtype=np.complex128) + + # create list for frequencies used and frequency bands + # of resulting connectivity data + freqs = list() + freq_bands = list() + idx_fi = [slice(None)] * len(out_shape) + idx_fj = [slice(None)] * len(out_shape) + for band_idx, band in enumerate(bands): + freq_idx = np.where((freqs_ > band[0]) & (freqs_ < band[1]))[0] + freqs.append(freqs_[freq_idx]) + freq_bands.append(np.mean(freqs_[freq_idx])) + + acc.fill(0.0) + for fi, fj in zip(freq_idx, freq_idx[1:]): + idx_fi[freq_dim] = fi + idx_fj[freq_dim] = fj + acc += ( + np.conj(cohy.get_data()[tuple(idx_fi)]) * cohy.get_data()[tuple(idx_fj)] + ) + + idx_fi[freq_dim] = band_idx + psi[tuple(idx_fi)] = np.imag(acc) + logger.info("[PSI Estimation Done]") + + # create a connectivity container + conn = EpochSpectralConnectivity( + data=psi, + names=names, + freqs=freq_bands, + n_nodes=n_nodes, + method="phase-slope-index", + spec_method=mode, + indices=indices, + freqs_computed=freqs, + n_tapers=n_tapers, + metadata=metadata, + events=events, + event_id=event_id, + ) + + return conn diff --git a/mne_connectivity/spectral/time.py b/mne_connectivity/spectral/time.py index a8bcdfb6..7af80215 100644 --- a/mne_connectivity/spectral/time.py +++ b/mne_connectivity/spectral/time.py @@ -577,8 +577,8 @@ def spectral_connectivity_time( conn = dict() conn_patterns = dict() for m in method: - # CaCoh complex-valued, all other methods real-valued - if m == "cacoh": + # Cohy and CaCoh complex-valued, all other methods real-valued + if m in ["cacoh", "cohy"]: con_scores_dtype = np.complex128 else: con_scores_dtype = np.float64 @@ -943,7 +943,7 @@ def _parallel_con( methods are called, the output is a tuple of lists containing arrays for the connectivity scores and patterns, respectively. """ - if "coh" in method: + if ("coh" in method) or ("cohy" in method): # psd if weights is not None: psd = weights * w @@ -1035,9 +1035,16 @@ def _pairwise_con(w, psd, x, y, method, kernel, foi_idx, faverage, weights): s_xy = np.squeeze(s_xy, axis=0) s_xy = _smooth_spectra(s_xy, kernel) out = [] - conn_func = {"plv": _plv, "ciplv": _ciplv, "pli": _pli, "wpli": _wpli, "coh": _coh} + conn_func = { + "plv": _plv, + "ciplv": _ciplv, + "pli": _pli, + "wpli": _wpli, + "coh": _coh, + "cohy": _cohy, + } for m in method: - if m == "coh": + if m in ["coh", "cohy"]: s_xx = psd[x] s_yy = psd[y] out.append(conn_func[m](s_xx, s_yy, s_xy)) @@ -1282,6 +1289,32 @@ def _coh(s_xx, s_yy, s_xy): return coh +def _cohy(s_xx, s_yy, s_xy): + """Compute coherencey given the cross spectral density and PSD. + + Parameters + ---------- + s_xx : array-like, shape (n_freqs, n_times) + The PSD of channel 'x'. + s_yy : array-like, shape (n_freqs, n_times) + The PSD of channel 'y'. + s_xy : array-like, shape (n_freqs, n_times) + The cross PSD between channel 'x' and channel 'y' across + frequency and time points. + + Returns + ------- + cohy : array-like, shape (n_freqs, n_times) + The estimated COHY. + """ + con_num = s_xy.mean(axis=-1, keepdims=True) + con_den = np.sqrt( + s_xx.mean(axis=-1, keepdims=True) * s_yy.mean(axis=-1, keepdims=True) + ) + cohy = con_num / con_den + return cohy + + def _compute_csd(x, y, weights): """Compute cross spectral density between signals x and y.""" if weights is not None: diff --git a/mne_connectivity/tests/test_effective.py b/mne_connectivity/tests/test_effective.py index 0ff0a8d6..77a6ed4b 100644 --- a/mne_connectivity/tests/test_effective.py +++ b/mne_connectivity/tests/test_effective.py @@ -1,7 +1,7 @@ import numpy as np from numpy.testing import assert_array_almost_equal -from mne_connectivity.effective import phase_slope_index +from mne_connectivity.effective import phase_slope_index, phase_slope_index_time def test_psi(): @@ -39,3 +39,40 @@ def test_psi(): assert np.all(conn_cwt.get_data() > 0) assert conn_cwt.shape[-1] == n_times + + +def test_psi_time(): + """Test Phase Slope Index (PSI) estimation across time.""" + sfreq = 50.0 + n_signals = 3 + n_epochs = 10 + n_times = 500 + rng = np.random.RandomState(42) + data = rng.randn(n_epochs, n_signals, n_times) + + # simulate time shifts + for i in range(n_epochs): + data[i, 1, 10:] = data[i, 0, :-10] # signal 0 is ahead + data[i, 2, :-10] = data[i, 0, 10:] # signal 2 is ahead + + # conn = phase_slope_index_time(data, mode="fourier", sfreq=sfreq, freqs=np.arange(4)) + + # assert conn.get_data(output="dense")[1, 0, 0] < 0 + # assert conn.get_data(output="dense")[2, 0, 0] > 0 + + # # only compute for a subset of the indices + indices = (np.array([0]), np.array([1])) + # conn_2 = phase_slope_index_time(data, mode="fourier", sfreq=sfreq, freqs=np.arange(4), indices=indices) + + # # the measure is symmetric (sign flip) + # assert_array_almost_equal( + # conn_2.get_data()[0, 0], -conn.get_data(output="dense")[1, 0, 0] + # ) + + freqs = np.arange(5.0, 20, 0.5) + conn_cwt = phase_slope_index_time( + data, mode="cwt_morlet", sfreq=sfreq, freqs=freqs, indices=indices + ) + + assert np.all(conn_cwt.get_data() > 0) + assert conn_cwt.shape[0] == n_epochs