# Source code for hyperit.hyperit

import numpy as np
from abc import ABC, abstractmethod
from scipy import stats
from typing import Tuple, List, Union, Any
from tqdm import tqdm
import matplotlib.pyplot as plt
import os
from pkg_resources import resource_filename

from jpype import isJVMStarted, startJVM, getDefaultJVMPath
from phyid.calculate import calc_PhiID
from phyid.utils import PhiID_atoms_abbr

from .utils import (
    setup_JArray, 
    ensure_three_dims,
    convert_names_to_indices, 
    convert_indices_to_names, 
    set_estimator
)

from enum import Enum

class MeasureType(Enum):
    """Information-theoretic measures supported by HyperIT.

    ``str(member)`` yields the lowercase key (e.g. ``'mi'``) that HyperIT
    passes on when selecting an estimator.
    """

    MI = 'mi'        # Mutual Information
    TE = 'te'        # Transfer Entropy
    PhyID = 'phyid'  # Integrated Information Decomposition

    def __str__(self):
        return f"{self.value}"
class HyperIT:
    """ HyperIT: Hyperscanning Analyses using Information Theoretic Measures.

    HyperIT is equipped to compute pairwise and multivariate Mutual Information (MI), Transfer Entropy (TE),
    and Integrated Information Decomposition (ΦID) for continuous time-series data. It is compatible with
    both intra-brain and inter-brain analyses and with both epoched and unepoched data. Multiple estimator
    choices and parameter customisations (via JIDT) are available, including Histogram/Binning, Gaussian,
    Kernel, KSG, and Symbolic. Statistical significance testing uses a permutation/bootstrapping approach.
    Visualisations of MI/TE matrices are also provided.

    Args:
        data1 (np.ndarray): Time-series data for participant 1. Can take shape (n_epo, n_chan, n_samples),
            (n_chan, n_samples) or (n_samples,) for epoched, unepoched, and single-channel data, respectively.
        data2 (np.ndarray): Time-series data for participant 2. Must have the same shape as data1. If it is
            None, an empty array, or identical to data1, data1 is analysed on its own (data2 becomes a copy).
        channel_names (List[str], optional): Channel names, given as a flat list of strings (used for both
            participants), as [[channel_names_p1], [channel_names_p2]], or as [[channel_names_p1]].
        standardise_data (bool, optional): Whether to standardise (z-score along the time axis) the data
            before analysis. Defaults to False.
        verbose (bool, optional): Whether constructor and analyses should output details and progress.
            Defaults to False.
        show_tqdm (bool, optional): Whether to show tqdm progress bars. Defaults to True.

    Note:
        This class requires numpy, mne, matplotlib, jpype (with the infodynamics.jar file bundled with
        this package), and phyid as dependencies. Before a HyperIT object can be created, users must first
        call HyperIT.setup_JVM() to initialise the Java Virtual Machine (JVM); it locates the bundled
        infodynamics.jar itself. Users can then create multiple HyperIT objects containing time-series
        data, later calling various functions for analysis.

        Data are automatically checked for consistency and dimensionality, identifying whether the analysis
        is intra- or inter-brain and whether the data are epoched:

        - If data is 3 dimensional, data is assumed to be epoched with shape (epochs, channels, time_points).
        - If data is 2 dimensional, data is assumed to be unepoched with shape (channels, time_points).
        - If data is 1 dimensional, data is assumed to be single channel time series with shape (time_points).
    """

    ## SETTING UP JVM
    # Class-wide flag: set by setup_JVM(); __init__ raises RuntimeError while it is still False.
    _jvm_initialised = False
[docs] @classmethod def setup_JVM(cls, verbose: bool = False) -> None: """Setup JVM if not already started. To be called once before creating any instances.""" if not cls._jvm_initialised: if not isJVMStarted(): jarLocation = resource_filename(__name__, 'infodynamics.jar') startJVM(getDefaultJVMPath(), "-ea", ('-Djava.class.path=' + jarLocation)) cls._jvm_initialised = True if verbose: print("JVM started successfully.") else: if verbose: print("JVM already started.") else: if verbose: print("JVM setup already completed.")
def __init__(self, data1: np.ndarray, data2: np.ndarray, channel_names: List[str] = None, standardise_data: bool = False, verbose: bool = False, show_tqdm: bool = True) -> None: if not self.__class__._jvm_initialised: raise RuntimeError("JVM has not been started. Call setup_JVM() before creating any instances of HyperIT.") self._verbose = verbose self._show_tqdm = show_tqdm self._channel_names = channel_names or None self._channel_indices1 = [] self._channel_indices2 = [] self._standardise_data = standardise_data self._epoch_avg_later = False self._data1 = data1 self._data2 = data2 self.__setup() # Store original data (that has been checked and manipulated) for resetting self._orig_channel_indices_1 = self._channel_indices1 self._orig_channel_indices_2 = self._channel_indices2 if self._verbose: print("HyperIT object created successfully.") if self._n_epo > 1: print(f"Epoched data detected. \nAssuming data passed have shape ({self._n_epo} epochs, {self._n_chan} channels, {self._n_samples} time points).") else: print(f"Unepoched data detected. \nAssuming data passed have shape ({self._n_chan} channels, {self._n_samples} time points).") def __setup(self): self.__check_data() self.__check_channels() self.__configure_data() _, self._n_epo, self._n_chan, self._n_samples = self._all_data.shape self._it_matrix = None self._initialise_parameter = None # These are default when HyperIT object is instantiated. self._roi = [] self._scale_of_organisation = 1 # 1 = micro organisation (single channel pairwise), n = meso- or n-scale organisation (n-sized groups of channels) ## DUNDER METHODS def __repr__(self) -> str: """ String representation of HyperIT object. 
""" analysis_type = 'Hyperscanning' channel_info = f"{self._channel_names[0]} and {self._channel_names[0][1]}" # Assuming self._channel_names[0] is a list of channel names for the first data set return (f"HyperIT Object: \n" f"{analysis_type} Analysis with {self._n_epo} epochs, {self._n_chan} channels, " f"and {self._n_samples} time points. \n" f"Channel names passed: \n" f"{channel_info}.") def __len__(self) -> int: """ Returns the number of epochs in the HyperIT object. """ return self._all_data.shape def __str__(self) -> str: """ String representation of HyperIT object. """ return self.__repr__() ## DATA, CHANNEL, & INITIALISATION CHECKING def __check_data(self) -> None: """ Checks the consistency and dimensionality of the time-series data and channel names. Sets the number of epochs, channels, and time points as object variables. Ensures: - Whether an intra-brain analysis is expected (when no second data set is provided or when data1==data2). - Data are numpy arrays. - Data shapes are consistent. - Data dimensions are either 2 or 3 dimensional. - Channel names are in correct format and match number of channels in data. - Data are standardised, if specified. """ if self._data2 is None or np.array_equal(self._data1, self._data2) or self._data2.shape == (0,): self._data2 = self._data1.copy() self._include_intra = False if not all(isinstance(data, np.ndarray) for data in [self._data1, self._data2]): raise ValueError("Time-series data must be numpy arrays.") if self._data1.shape != self._data2.shape: raise ValueError("Time-series data must have the same shape for both participants.") if self._data1.ndim not in [1,2,3]: raise ValueError(f"Unexpected number of dimensions in time-series data: {self._data1.ndim}. 
Expected 3 dimensions (epochs, channels, time_points) or 2 dimensions (channels, time_points) or 1 dimension (time_points).") self._cannot_be_epoched = False if self._data1.ndim == 1: self._cannot_be_epoched = True # Ensure data is 3 dimensional and has shape (n_epochs, n_channels, n_samples). self._data1, self._data2 = map(ensure_three_dims, (self._data1, self._data2)) def __check_channel_names(self) -> None: """ Checks the consistency of the channel names provided. """ if not isinstance(self._channel_names, list): raise ValueError("Channel names must be a list of strings or a list of lists of strings.") elif isinstance(self._channel_names[0], str): if not all(isinstance(name, str) for name in self._channel_names): raise ValueError("All elements must be strings if the first element is a string.") self._channel_names = [self._channel_names, self._channel_names.copy()] elif isinstance(self._channel_names[0], list): if not all(isinstance(sublist, list) for sublist in self._channel_names): raise ValueError("All elements must be lists of strings if the first element is a list.") for sublist in self._channel_names: if not all(isinstance(name, str) for name in sublist): raise ValueError("All sublists must be lists of strings.") if len(self._channel_names) == 1: self._channel_names = self._channel_names * 2 else: raise ValueError("Channel names must be either a list of strings or a list of lists of strings.") if any(len(names) != self._n_chan for names in self._channel_names): raise ValueError("The number of channels in time-series data does not match the length of channel_names.") def __check_channels(self) -> None: """ Checks the consistency of the channel names provided and sets the number of channels as an object variable. 
""" self._n_chan = self._data1.shape[1] self._channel_indices1, self._channel_indices2 = np.arange(self._n_chan), np.arange(self._n_chan) if self._channel_names: self.__check_channel_names() else: self._channel_names = [np.arange(self._n_chan), np.arange(self._n_chan)] def __configure_data(self) -> None: """ Configures the data for analysis by standardising and storing original data to be recalled if roi resets. """ if self._standardise_data: self._data1 = (self._data1 - np.mean(self._data1, axis=-1, keepdims=True)) / np.std(self._data1, axis=-1, keepdims=True) self._data2 = (self._data2 - np.mean(self._data2, axis=-1, keepdims=True)) / np.std(self._data2, axis=-1, keepdims=True) self._all_data = np.stack([self._data1, self._data2], axis=0) # both data should now be three dimensional! ## DEFINING REGIONS OF INTEREST @property def roi(self) -> List[List[Union[str, int, list]]]: """Regions of interest for both data of the HyperIT object (defining spatial scale of organisation). To set this, call .roi(roi_list). HyperIT is defaulted to **micro-scale** analysis (individual channels) but specific channels can be specified for pointwise comparison: ``roi_list = [['Fp1', 'Fp2'], ['F3', 'F4']]``, for example. For **meso-scale** analysis (clusters of channels), equally-sized and equally-numbered clusters must be defined for both sets of data in the following way: ``roi_list = [[[PP1_cluster_1], ..., [PP1_cluster_n]], [[PP2_cluster_1], ..., [PP2_cluster_n]]]``. Finally, for **macro-scale** analysis (all channels per person), the specification can be set as ``roi_list = [[PP1_all_channels][PP2_all_channels]]`` (note that PP1_all_channels and PP2_all_channels should be list themselves). Importantly, as long as the ``channel_names`` are instantiated properly in the initiation of the HyperIT object, the ROI can even be given as a lists of channel indices (integers). In any case, to set these scales of organisations, simply amend the ``roi`` property of the HyperIT object used. 
Call ``roi.reset_roi()`` to reset the ROI to all channels. """ return self._roi @roi.setter def roi(self, roi_list: List[List[Union[str, int, list]]]) -> None: """Sets the region of interest for both data of the HyperIT object. Args: roi_list: A list of lists, where each sublist is a ROI containing either strings of EEG channel names or integer indices or multiple ROIs formed as another list. Raises: ValueError: If the value is not a list of lists, if elements of the sublists are not of type str or int, or if sublists do not have the same length. """ ## DETERMINE SCALE OF ORGANISATION # 1: Micro organisation (specified channels, pairwise comparison) e.g., roi_list = [['Fp1', 'Fp2'], ['F3', 'F4']] # n: Meso- or n-scale organisation (n specified channels per ROI group) e.g., roi_list = [[ ['Fp1', 'Fp2'], ['CP1', 'CP2'] ], n CHANNELS IN EACH GROUP FOR PARTICIPANT 1 # [ ['Fp1', 'Fp2'], ['F3', 'F4'] ]]. # Check if roi_list is structured for pointwise channel comparison if all(isinstance(sublist, list) and not any(isinstance(item, list) for item in sublist) for sublist in roi_list): self._scale_of_organisation = 1 # Check if roi_list is structured for grouped channel comparison elif all(isinstance(sublist, list) and all(isinstance(item, list) for item in sublist) for sublist in roi_list): # Ensure uniformity in the number of groups across both halves num_groups_x = len(roi_list[0]) num_groups_y = len(roi_list[1]) if num_groups_x == num_groups_y: self._soi_groups = num_groups_x group_lengths = [len(group) for half in roi_list for group in half] if len(set(group_lengths)) == 1: self._scale_of_organisation = group_lengths[0] self._initialise_parameter = (self._scale_of_organisation, self._scale_of_organisation) else: raise ValueError("Not all groups have the same number of channels.") else: raise ValueError("ROI halves do not have the same number of channel groups per participant.") else: raise ValueError("ROI structure is not recognised.") if self._verbose: print(f"Scale 
of organisation: {self._scale_of_organisation} channels.") print(f"Groups of channels: {self._soi_groups}") self._channel_indices1, self._channel_indices2 = [convert_names_to_indices(self._channel_names, part, idx) for idx, part in enumerate(roi_list)] self._roi = [self._channel_indices1, self._channel_indices2]
[docs] def reset_roi(self) -> None: """Resets the region of interest for both data of the HyperIT object to all channels.""" self._scale_of_organisation = 1 self._initialise_parameter = None self._channel_indices1, self._channel_indices2 = self._orig_channel_indices_1, self._orig_channel_indices_2 self._roi = []
## HARD-CODED HISTOGRAM AND SYMBOLIC MI ESTIMATION FUNCTIONS def __estimate_mi_hist(self, s1: np.ndarray, s2: np.ndarray) -> float: """Calculates Mutual Information using Histogram/Binning Estimator for time-series signals.""" def calc_fd_bins(X: np.ndarray, Y: np.ndarray) -> int: """Calculates the optimal frequency-distribution bin size for histogram estimator using Freedman-Diaconis Rule.""" fd_bins_X = np.ceil(np.ptp(X) / (2.0 * stats.iqr(X) * len(X)**(-1/3))) fd_bins_Y = np.ceil(np.ptp(Y) / (2.0 * stats.iqr(Y) * len(Y)**(-1/3))) fd_bins = int(np.ceil((fd_bins_X+fd_bins_Y)/2)) return fd_bins def hist_calc_mi(X, Y): # Joint probability distribution j_hist, _, _ = np.histogram2d(X, Y, bins=calc_fd_bins(X, Y)) pxy = j_hist / np.sum(j_hist) # Joint probability distribution # Marginals probability distribution px = np.sum(pxy, axis=1) py = np.sum(pxy, axis=0) # Entropies Hx = -np.sum(px * np.log2(px + np.finfo(float).eps)) Hy = -np.sum(py * np.log2(py + np.finfo(float).eps)) Hxy = -np.sum(pxy * np.log2(pxy + np.finfo(float).eps)) return Hx + Hy - Hxy result = hist_calc_mi(s1, s2) if self._calc_statsig: permuted_mi_values = [] for _ in range(self._stat_sig_perm_num): np.random.shuffle(s2) permuted_mi = hist_calc_mi(s1, s2) permuted_mi_values.append(permuted_mi) mean_permuted_mi = np.mean(permuted_mi_values) std_permuted_mi = np.std(permuted_mi_values) p_value = np.sum(permuted_mi_values >= result) / self._stat_sig_perm_num return np.array([result, mean_permuted_mi, std_permuted_mi, p_value]) return result def __estimate_mi_symb(self, s1: np.ndarray, s2: np.ndarray, k: int = 3, delay: int = 1) -> float: """Calculates Mutual Information using Symbolic Estimator for time-series signals.""" symbol_weights = np.power(k, np.arange(k)) def symb_symbolise(X: np.ndarray, k: int, delay: int) -> np.ndarray: Y = np.empty((k, len(X) - (k - 1) * delay)) for i in range(k): Y[i] = X[i * delay:i * delay + Y.shape[1]] return Y.T def symb_normalise_counts(d) -> None: total = 
sum(d.values()) return {key: value / total for key, value in d.items()} if total > 0 else d def symb_calc_mi(X: np.ndarray, Y: np.ndarray, k: int, delay: int) -> float: X_symb = symb_symbolise(X, k, delay).argsort(kind='quicksort') Y_symb = symb_symbolise(Y, k, delay).argsort(kind='quicksort') symbol_hash_X = (np.multiply(X_symb, symbol_weights)).sum(axis=1) symbol_hash_Y = (np.multiply(Y_symb, symbol_weights)).sum(axis=1) p_xy, p_x, p_y = [dict() for _ in range(3)] for i in range(len(symbol_hash_X)): xy = (symbol_hash_X[i], symbol_hash_Y[i]) x, y = symbol_hash_X[i], symbol_hash_Y[i] p_xy[xy] = p_xy.get(xy, 0) + 1 p_x[x] = p_x.get(x, 0) + 1 p_y[y] = p_y.get(y, 0) + 1 p_xy, p_x, p_y = [np.array(list(symb_normalise_counts(d).values())) for d in [p_xy, p_x, p_y]] entropy_X = -np.sum(p_x * np.log2(p_x + np.finfo(float).eps)) entropy_Y = -np.sum(p_y * np.log2(p_y + np.finfo(float).eps)) entropy_XY = -np.sum(p_xy * np.log2(p_xy + np.finfo(float).eps)) return entropy_X + entropy_Y - entropy_XY result = symb_calc_mi(s1, s2, k, delay) if self._calc_statsig: permuted_mi_values = [] for _ in range(self._stat_sig_perm_num): s2_permuted = np.random.permutation(s2) # Use permutation to avoid modifying s2 in place permuted_mi = symb_calc_mi(s1, s2_permuted, k, delay) permuted_mi_values.append(permuted_mi) mean_permuted_mi = np.mean(permuted_mi_values) std_permuted_mi = np.std(permuted_mi_values) p_value = np.sum(np.array(permuted_mi_values) >= result) / self._stat_sig_perm_num return np.array([result, mean_permuted_mi, std_permuted_mi, p_value]) return result ## ESTIMATION AND HELPER FUNCTIONS def __delay_timeseries(self, lag: int) -> None: """ Only for Kernel TE which currently has no built-in delay function. 
Manually delay data2 by specified lag""" newY = np.array(self._data2[..., lag:]) newX = np.array(self._data1[..., :-lag]) self._data1 = newX self._data2 = newY self._n_samples = newX.shape[-1] def __which_estimator(self, measure: str) -> None: """Determines the estimator to use based on the measure type and sets the estimator, properties, and initialisation parameters.""" self._estimator_name, calculator, properties, initialise_parameter = set_estimator(self._estimator, measure, self._params) # from utils.py function if calculator: self._Calc = calculator() if properties: for key, value in properties.items(): self._Calc.setProperty(key, value) if initialise_parameter: self._initialise_parameter = initialise_parameter if self._measure == MeasureType.TE and self._estimator_name == 'kernel': __delay_timeseries(int(self._params.get("delay",1))) def __setup_matrix(self) -> None: """ Sets up the matrices for Mutual Information, Transfer Entropy, or Integrated Information Decomposition. """ # POINTWISE CHANNEL COMPARISON (If ROI is selected, this will pick specific channels, otherwise data stays the same) if self._scale_of_organisation == 1: self._it_data1 = self._data1[:, self._channel_indices1, :] self._it_data2 = self._data2[:, self._channel_indices2, :] self._n_chan = len(self._channel_indices1) else: self._it_data1 = self._data1.copy() self._it_data2 = self._data2.copy() self._loop_range = self._n_chan if self._scale_of_organisation == 1 else self._soi_groups if self._include_intra: self._loop_range *= 2 # Will give shape (n_epo, n_chan_or_group*2, n_samples) # Note that data1 and data2 will be identical temp1 = self._it_data1.copy() temp2 = self._it_data2.copy() self._it_data1 = np.concatenate((temp1, temp2), axis=1) self._it_data2 = np.concatenate((temp1, temp2), axis=1) if self._scale_of_organisation != 1: temp_list = [self._roi[0],[[item + self._n_chan for item in sublist] for sublist in self._roi[1]]] self._roi = [sublist for outer_list in temp_list for sublist 
in outer_list] if self._measure == MeasureType.MI or self._measure == MeasureType.TE: if self._calc_statsig: self._it_matrix = np.zeros((1, self._loop_range, self._loop_range, 4)) if self._epoch_average else np.zeros((self._n_epo, self._loop_range, self._loop_range, 4)) return self._it_matrix = np.zeros((1, self._loop_range, self._loop_range)) if self._epoch_average else np.zeros((self._n_epo, self._loop_range, self._loop_range)) return self._it_matrix = np.zeros((1, self._loop_range, self._loop_range, 16)) if self._epoch_average else np.zeros((self._n_epo, self._loop_range, self._loop_range, 16)) def __initialise_estimator(self) -> None: if not self._initialise_parameter: self._Calc.initialise() return if self._measure == MeasureType.TE: if self._estimator == 'symbolic': # symbolic estimator takes only one argument so cannot be unrolled. self._Calc.initialise(self._initialise_parameter) return self._Calc.initialise(*self._initialise_parameter) def __estimate_it(self, s1: np.ndarray, s2: np.ndarray) -> float | np.ndarray: """ Estimates Mutual Information or Transfer Entropy for a pair of time-series signals using JIDT estimators. """ self._Calc.setObservations(setup_JArray(s1), setup_JArray(s2)) result = self._Calc.computeAverageLocalOfObservations() * np.log(2) # Conduct significance testing if self._calc_statsig: stat_sig = self._Calc.computeSignificance(self._stat_sig_perm_num) return np.array([result, stat_sig.getMeanOfDistribution(), stat_sig.getStdOfDistribution(), stat_sig.pValue]) return float(result) def __estimate_it_epoch_average(self, s1: np.ndarray, s2: np.ndarray) -> float: """ Estimates Mutual Information or Transfer Entropy for a pair of time-series signals using JIDT estimators. 
s1 and s2 should have shape (epochs, samples) referring to a pairwise comparison of two channels.""" self._Calc.startAddObservations() for epoch in range(self._n_epo): self._Calc.addObservations(setup_JArray(s1[epoch]), setup_JArray(s2[epoch])) self._Calc.finaliseAddObservations() result = self._Calc.computeAverageLocalOfObservations() * np.log(2) # Conduct significance testing if self._calc_statsig: stat_sig = self._Calc.computeSignificance(self._stat_sig_perm_num) return np.array([result, stat_sig.getMeanOfDistribution(), stat_sig.getStdOfDistribution(), stat_sig.pValue]) return result def __estimate_atoms(self, s1: np.ndarray, s2: np.ndarray) -> np.ndarray: """ Estimates Integrated Information Decomposition for a pair of time-series signals using phyid package. If epoch_average, s1 and s2 should have shape (epochs, samples) referring to a pairwise comparison of two channels.""" try: atoms_results_xy, _ = calc_PhiID(s1, s2, tau=self._tau, kind='gaussian', redundancy=self._redundancy) except Exception as e: if self._verbose: print(f'Warning: error handling timeseries. They are likely identical or similar timeseries. Setting results to 0. Error: {e}', flush=True) return np.zeros(16) calc_atoms_xy = np.mean(np.array([atoms_results_xy[_] for _ in PhiID_atoms_abbr]), axis=1) return calc_atoms_xy def __filter_estimation(self, s1: np.ndarray, s2: np.ndarray) -> float | np.ndarray: """ Filters the estimation in case incompatible with JIDT. s1 and s2 should have shape (samples) if epoch_average = False; or (samples, epochs) if epoch_average = True, referring to a pairwise comparison of two channels. 
""" if self._measure == MeasureType.MI: match self._estimator: case 'histogram': return self.__estimate_mi_hist(s1, s2) case 'symbolic': return self.__estimate_mi_symb(s1, s2, self._params.get('k', 3), self._params.get('delay', 1)) elif self._measure == MeasureType.PhyID: return self.__estimate_atoms(s1, s2) self.__initialise_estimator() return self.__estimate_it_epoch_average(s1, s2) if self._epoch_average else self.__estimate_it(s1, s2) def __compute_pair_or_group(self, epoch: int, i: int, j: int) -> None: """ Computes the Mutual Information or Transfer Entropy for a pair of channels or groups of channels. """ channel_or_group_i = i if self._scale_of_organisation == 1 else self._roi[i] channel_or_group_j = j if self._scale_of_organisation == 1 else self._roi[j] # Data needs to have shape (samples, channels/groups), if not pointwise comparison # (this is how both JIDT and phyid handle and expect incoming multivariate data) # (Note that .T does not affect pointwise comparison as it is already in the correct shape) # If epoch_average, epoch = 0 (this dimension will later be squeezed out but left here for similarity in data structure handling) if self._epoch_average: s1, s2 = self._it_data1[:, channel_or_group_i, :].T, self._it_data2[:, channel_or_group_j, :].T else: s1, s2 = self._it_data1[epoch, channel_or_group_i, :].T, self._it_data2[epoch, channel_or_group_j, :].T if self._include_intra: # Avoid self-connections in matrix if i == j: return # MI is symmetric so only compute half of matrix if self._measure == MeasureType.MI: if i > j: result = self.__filter_estimation(s1, s2) self._it_matrix[epoch, i, j] = result self._it_matrix[epoch, j, i] = result return else: return self._it_matrix[epoch, i, j] = self.__filter_estimation(s1, s2) return # ELSE: # MI is symmetric so only compute half of matrix if self._measure == MeasureType.MI: # The diagonal should be computed! 
(As this refers to B1Ch1 --> B2Ch1 rather than B1Ch1 --> B1Ch1) if i >= j: result = self.__filter_estimation(s1, s2) self._it_matrix[epoch, i, j] = result self._it_matrix[epoch, j, i] = result return else: return self._it_matrix[epoch, i, j] = self.__filter_estimation(s1, s2) return def __build_matrix(self) -> None: if self._epoch_average: for i in range(self._loop_range): for j in range(self._loop_range): self.__compute_pair_or_group(0, i, j) return for epoch in range(self._n_epo): tqdm_desc = f"Computing Epoch {epoch+1}/{self._n_epo}..." for i in tqdm(range(self._loop_range), desc = tqdm_desc, disable = not self._show_tqdm): for j in range(self._loop_range): self.__compute_pair_or_group(epoch, i, j) ## MAIN CALCULATION FUNCTIONS def __main_calc(self) -> np.ndarray: self.__setup_matrix() self.__build_matrix() if self._epoch_average: self._it_matrix = np.squeeze(self._it_matrix, axis=0) # remove redundant epoch dimension to give shape (2*ch, 2*ch, :) if self._epoch_avg_later: self._it_matrix = np.squeeze(np.mean(self._it_matrix, axis=0)) if self._vis: self.__prepare_vis() return self._it_matrix def __setup_mite_calc(self, estimator: str, include_intra: bool, calc_statsig: bool, stat_sig_perm_num: int, p_threshold: float, epoch_average: bool, vis: bool, plot_epochs: List, **kwargs) -> np.ndarray: """ General function for computing Mutual Information or Transfer Entropy. 
""" self._estimator = estimator.lower() if not (self._measure == MeasureType.MI and estimator.lower() == 'ksg') else 'ksg1' self._calc_statsig = calc_statsig self._include_intra = include_intra self._vis = vis self._plot_epochs = (plot_epochs or None) if self._vis and not epoch_average else None self._params = kwargs self._stat_sig_perm_num = stat_sig_perm_num self._p_threshold = p_threshold if epoch_average and self._cannot_be_epoched: raise ValueError("epoch_average cannot be true when data is 1-dimensional!") self._epoch_average = epoch_average self._epoch_avg_later = False if self._epoch_average and ((self._measure == MeasureType.MI and self._estimator in ["histogram", "symbolic"]) or (self._roi != [])): self._epoch_average = False self._epoch_avg_later = True # Set up the estimator and properties self.__which_estimator(str(self._measure)) return self.__main_calc() def __setup_atom_calc(self, tau: int, redundancy: str, include_intra: bool) -> np.ndarray: """ General function for computing Integrated Information Decomposition. """ self._tau = tau self._redundancy = redundancy self._include_intra = include_intra self._vis = False return self.__main_calc() ## VISUALISATION FUNCTIONS def __prepare_vis(self) -> None: """ Prepares the visualisation of Mutual Information, Transfer Entropy, or Integrated Information Decomposition matrix/matrices. """ if self._plot_epochs is None or self._plot_epochs == [-1]: self._plot_epochs = range(self._n_epo) else: self._plot_epochs = [ep - 1 for ep in self._plot_epochs if ep < self._n_epo] if not self._plot_epochs: raise ValueError("No valid epochs found in the input list.") if self._epoch_average or self._epoch_avg_later: self._plot_epochs = None if self._n_chan == 1: print("Single channel detected. 
No visualisation possible.") return if self._verbose: print(f"Plotting {self._measure_title}...") self.__plot_it() def __plot_matrix(self, results: np.ndarray, epoch: int, title: str, global_max: float, source_channel_names: List, target_channel_names: List, source_str: List, target_str: List) -> None: plt.figure(figsize=(12, 10)) img = plt.imshow(results, cmap='BuPu', vmin=0, vmax=global_max, aspect='auto') # if epoch is not None: # plt.title(f'{title} Epoch {epoch+1}', pad=20) # else: plt.title(title, pad=20) if self._calc_statsig: for i in range(self._loop_range): for j in range(self._loop_range): if i == j: continue p_val = float(self._it_matrix[epoch, i, j, 3]) if epoch is not None else float(self._it_matrix[i, j, 3]) if p_val < self._p_threshold: value = self._it_matrix[epoch, i, j, 0] if epoch is not None else self._it_matrix[i, j, 0] normalised_value = (value - np.min(results)) / (np.max(results) - np.min(results)) text_colour = 'white' if normalised_value > 0.5 else 'black' p_value_text = f'p={p_val:.3f}' if p_value_text == 'p=0.000': p_value_text = 'p<0.001' plt.text(j, i, p_value_text, ha='center', va='center', color=text_colour, fontsize=8, fontweight='bold') cbar = plt.colorbar(img) # cbar.set_label(self._measure_title, rotation=270, labelpad=20) n_ticks = 8 ticks = np.linspace(0, global_max, n_ticks) cbar.set_ticks(ticks) cbar.set_ticklabels([f"{tick:.2f}" if tick != global_max else f"{tick:.2f} (max)" for tick in ticks]) x_tick_label = target_channel_names.copy() y_tick_label = source_channel_names.copy() ticks = range(self._loop_range) rotate_x = 90 rotate_y = 0 if self._include_intra: if self._scale_of_organisation != 1: x_tick_label = source_str + target_str else: # x_tick_label = ['X_' + str(s) for s in y_tick_label] + ['Y_' + str(s) for s in x_tick_label] x_tick_label = [str(s) for s in y_tick_label] + [str(s) for s in x_tick_label] y_tick_label = x_tick_label else: plt.xlabel('Target') plt.ylabel('Source') plt.xticks(ticks, x_tick_label, 
rotation=rotate_x) plt.yticks(ticks, y_tick_label, rotation=rotate_y) plt.tick_params(axis='x', which='both', bottom=False, top=False, labelbottom=False, labeltop=True) plt.tick_params(axis='y', which='both', right=False, left=True, labelleft=True) plt.show() def __plot_it(self) -> None: """Plots the Mutual Information or Transfer Entropy matrix for visualisation. Axes labelled with source and target channel names. """ title = f'{self._measure_title} | {self._estimator_name} \n' source_channel_names = convert_indices_to_names(self._channel_names, self._channel_indices1, 0) target_channel_names = convert_indices_to_names(self._channel_names, self._channel_indices2, 1) source_str = [] target_str = [] global_max = np.max(self._it_matrix[..., 0]) if self._calc_statsig else np.max(self._it_matrix) if self._scale_of_organisation > 1: print("Plotting for grouped channels.") print("Source Groups:") for i in range(self._soi_groups): source_str.append(f'X{i+1}_{source_channel_names[i]}') print(f"{i+1}: {source_channel_names[i]}") print("\nTarget Groups:") for i in range(self._soi_groups): target_str.append(f'Y{i+1}_{target_channel_names[i]}') print(f"{i+1}: {target_channel_names[i]}") if self._plot_epochs is None: results = self._it_matrix[:, :, 0] if self._calc_statsig else self._it_matrix[:, :] self.__plot_matrix(results, None, title, global_max, source_channel_names, target_channel_names, source_str, target_str) return for epoch in self._plot_epochs: results = self._it_matrix[epoch, :, :, 0] if self._calc_statsig else self._it_matrix[epoch, :, :] self.__plot_matrix(results, epoch, title, global_max, source_channel_names, target_channel_names, source_str, target_str) ## HIGH-LEVEL INTERFACE FUNCTIONS
def compute_mi(self, estimator: str = 'gaussian', include_intra: bool = False,
               epoch_average: bool = True, calc_statsig: bool = False,
               stat_sig_perm_num: int = 100, p_threshold: float = 0.05,
               vis: bool = False, plot_epochs: Union[List[int], None] = None,
               **kwargs) -> np.ndarray:
    """
    Computes Mutual Information (MI) between data (time-series signals) using specified estimator.

    Args:
        estimator (str, optional): Specifies the MI estimator to use. Defaults to 'gaussian'.
        include_intra (bool, optional): If True, includes intra-brain analyses. Defaults to False.
        epoch_average (bool, optional): If True, results are averaged across epochs. Defaults to True.
        calc_statsig (bool, optional): If True, performs statistical significance testing. Defaults to False.
        stat_sig_perm_num (int, optional): Number of permutations for statistical significance testing. Defaults to 100.
        p_threshold (float, optional): Threshold for statistical significance testing. Defaults to 0.05.
        vis (bool, optional): If True, results will be visualised. Defaults to False.
        plot_epochs (List[int], optional): Specifies which epochs to plot. None plots all. Defaults to None.
        **kwargs: Additional keyword arguments for MI estimators.

    Returns:
        np.ndarray: A mutual information matrix. The shape of the matrix is determined by the
        `include_intra`, `epoch_average`, and `calc_statsig` flags:
            - If `include_intra` is False, shape is (..., n_chan, n_chan, ...).
            - If `include_intra` is True, shape is (..., 2*n_chan, 2*n_chan, ...).
            - If `epoch_average` is True, shape is (n_chan, n_chan, ...) or (2*n_chan, 2*n_chan, ...).
            - If `epoch_average` is False, shape is (n_epo, n_chan, n_chan, ...) or (n_epo, 2*n_chan, 2*n_chan, ...).
            - If `calc_statsig` is True, an additional last dimension (size 4) contains statistical
              significance results: [MI value, mean, standard deviation, p-value].

    Note:
        When `include_intra` is True, the two channel axes can be split at `n_chan`.
        The slices below are for the epoch-averaged case; when `epoch_average` is False,
        prepend `:` for the leading epoch axis (e.g. matrix[:, :n_chan, :n_chan]).
        Any trailing statistics dimension is carried along automatically.
            - `intra1`: matrix[:n_chan, :n_chan]
            - `intra2`: matrix[n_chan:, n_chan:]
            - `inter12`: matrix[:n_chan, n_chan:]
            - `inter21`: matrix[n_chan:, :n_chan]

    Available Estimators and Their Parameters:
        - histogram:
            - None.
        - ksg1:
            - kraskov_param (int, default=4).
            - normalise (bool, default=True).
        - ksg2:
            - kraskov_param (int, default=4).
            - normalise (bool, default=True).
        - kernel:
            - kernel_width (float, default=0.25).
            - normalise (bool, default=True).
        - gaussian:
            - normalise (bool, default=True).
        - symbolic:
            - k (int, default=3): Embedding history or symbol length.
            - delay (int, default=1).
    """
    self._measure = MeasureType.MI
    self._measure_title = 'Mutual Information'
    # NOTE(review): epoch_average is passed after p_threshold here, not in signature
    # order; presumably this matches __setup_mite_calc's positional parameters
    # (defined elsewhere) — confirm before reordering.
    return self.__setup_mite_calc(estimator, include_intra, calc_statsig, stat_sig_perm_num,
                                  p_threshold, epoch_average, vis, plot_epochs, **kwargs)
def compute_te(self, estimator: str = 'gaussian', include_intra: bool = False,
               epoch_average: bool = True, calc_statsig: bool = False,
               stat_sig_perm_num: int = 100, p_threshold: float = 0.05,
               vis: bool = False, plot_epochs: Union[List[int], None] = None,
               **kwargs) -> np.ndarray:
    """
    Computes Transfer Entropy (TE) between time-series data using a specified estimator.

    This function allows for intra-brain and inter-brain analyses and includes optional
    statistical significance testing. Data1 is considered the source and Data2 the target.

    Args:
        estimator (str, optional): Specifies the TE estimator to use. Defaults to 'gaussian'.
        include_intra (bool, optional): Whether to include intra-brain comparisons in the output matrix. Defaults to False.
        epoch_average (bool, optional): If True, results are averaged across epochs. Defaults to True.
        calc_statsig (bool, optional): Whether to calculate statistical significance of TE values. Defaults to False.
        stat_sig_perm_num (int, optional): Number of permutations for statistical significance testing. Defaults to 100.
        p_threshold (float, optional): Threshold for statistical significance testing. Defaults to 0.05.
        vis (bool, optional): Enables visualisation of the TE matrix if set to True. Defaults to False.
        plot_epochs (List[int], optional): Specifies which epochs to plot. If None, plots all epochs. Defaults to None.
        **kwargs: Additional parameters for TE estimators.

    Returns:
        np.ndarray: A transfer entropy matrix. The shape of the matrix is determined by the
        `include_intra`, `epoch_average`, and `calc_statsig` flags:
            - If `include_intra` is False, shape is (..., n_chan, n_chan, ...).
            - If `include_intra` is True, shape is (..., 2*n_chan, 2*n_chan, ...).
            - If `epoch_average` is True, shape is (n_chan, n_chan, ...) or (2*n_chan, 2*n_chan, ...).
            - If `epoch_average` is False, shape is (n_epo, n_chan, n_chan, ...) or (n_epo, 2*n_chan, 2*n_chan, ...).
            - If `calc_statsig` is True, an additional last dimension (size 4) contains statistical
              significance results: [TE value, mean, standard deviation, p-value].

    Note:
        When `include_intra` is True, the two channel axes can be split at `n_chan`.
        The slices below are for the epoch-averaged case; when `epoch_average` is False,
        prepend `:` for the leading epoch axis (e.g. matrix[:, :n_chan, :n_chan]).
        Any trailing statistics dimension is carried along automatically.
            - `intra1`: matrix[:n_chan, :n_chan]
            - `intra2`: matrix[n_chan:, n_chan:]
            - `inter12`: matrix[:n_chan, n_chan:]
            - `inter21`: matrix[n_chan:, :n_chan]

    Available Estimators and Their Parameters:
        - `ksg`:
            - k, l (int, default=1): Target and source embedding history lengths.
            - k_tau, l_tau (int, default=1): Target and source embedding delays.
            - delay (int, default=1): Delay parameter for temporal dependency (source-to-target lag).
            - kraskov_param (int, default=1).
            - normalise (bool, default=True).
        - `kernel`:
            - kernel_width (float, default=0.5).
            - delay (int, default=1).
            - normalise (bool, default=True).
        - `gaussian`:
            - k, l (int, default=1): Target and source embedding history lengths.
            - k_tau, l_tau (int, default=1): Target and source embedding delays.
            - delay (int, default=1): Delay parameter for temporal dependency (source-to-target lag).
            - bias_correction (bool, default=False).
            - normalise (bool, default=True).
        - `symbolic`:
            - k (int, default=1): Embedding history length.
            - normalise (bool, default=True).
    """
    self._measure = MeasureType.TE
    self._measure_title = 'Transfer Entropy'
    # NOTE(review): epoch_average is passed after p_threshold here, not in signature
    # order; presumably this matches __setup_mite_calc's positional parameters
    # (defined elsewhere) — confirm before reordering.
    return self.__setup_mite_calc(estimator, include_intra, calc_statsig, stat_sig_perm_num,
                                  p_threshold, epoch_average, vis, plot_epochs, **kwargs)
def compute_atoms(self, tau: int = 1, redundancy: str = 'MMI', include_intra: bool = False,
                  epoch_average: bool = True) -> Tuple[np.ndarray, np.ndarray]:
    """
    Function to compute Integrated Information Decomposition (ΦID) between data
    (time-series signals) instantiated in the HyperIT object.

    Args:
        tau (int, optional): Time-lag parameter. Defaults to 1.
        redundancy (str, optional): Redundancy function to use. Defaults to 'MMI' (Minimum Mutual Information).
        include_intra (bool, optional): Whether to include intra-brain analysis. Defaults to False.
        epoch_average (bool, optional): If True, results are averaged across epochs. Defaults to True.

    Returns:
        Tuple[np.ndarray, np.ndarray]: The result of `__setup_atom_calc`, annotated as a pair
        of arrays. NOTE(review): what each element holds is decided in `__setup_atom_calc`
        (not visible here) — confirm. The channel axes follow this layout:
            - If `include_intra` is False, shape is (..., n_chan, n_chan, ...).
            - If `include_intra` is True, shape is (..., 2*n_chan, 2*n_chan, ...).
            - If `epoch_average` is True, shape is (n_chan, n_chan, ...) or (2*n_chan, 2*n_chan, ...).
            - If `epoch_average` is False, shape is (n_epo, n_chan, n_chan, ...) or (n_epo, 2*n_chan, 2*n_chan, ...).
        No statistical significance testing is offered by this method.

    Note:
        When `include_intra` is True, the two channel axes can be split at `n_chan`.
        The slices below are for the epoch-averaged case; when `epoch_average` is False,
        prepend `:` for the leading epoch axis (e.g. matrix[:, :n_chan, :n_chan]).
            - `intra1`: matrix[:n_chan, :n_chan]
            - `intra2`: matrix[n_chan:, n_chan:]
            - `inter12`: matrix[:n_chan, n_chan:]
            - `inter21`: matrix[n_chan:, :n_chan]

        Visualisation is not currently supported.

    Available Redundancy Functions:
        - 'MMI': Minimum Mutual Information
        - 'CCS': Common Change in Surprisal
    """
    self._measure = MeasureType.PhyID
    self._measure_title = 'Integrated Information Decomposition'

    # Given phyid only takes (samples, n), cannot pass multiple channels (ROIs) with epochs.
    # Revert back to looping through each epoch and channel combination and average at the end.
    self._epoch_average = False if self._roi is not None else epoch_average
    # NOTE(review): set unconditionally, even when the caller passes epoch_average=False;
    # whether the deferred averaging honours that choice is decided in __setup_atom_calc
    # (not visible here) — confirm.
    self._epoch_avg_later = True
    return self.__setup_atom_calc(tau, redundancy, include_intra)