Source code for sekupy.analysis.states.base

import numpy as np
import os
from sekupy.utils.dataset import get_ds_data
from sekupy.analysis.base import Analyzer
from sekupy.analysis.states.subsamplers import VarianceSubsampler
from sklearn import cluster
from sklearn.pipeline import Pipeline
from scipy.spatial.distance import euclidean
import mne.time_frequency as freq

import logging
logger = logging.getLogger(__name__)


class Clustering(Analyzer):
    """Identify recurring "states" in a dataset with a clustering estimator.

    The estimator is fitted on a (sub-sampled) version of the dataset; the
    resulting centroids are then used to label every sample of the full
    dataset.  All results are stored in the ``scores`` dictionary by
    :meth:`fit`.

    Parameters
    ----------
    estimator : estimator object or Pipeline, optional
        Object exposing ``fit(X)`` and, after fitting, at least one of
        ``labels_``, ``cluster_centers_``, ``means_`` or ``predict``.
        When ``None`` (default) a fresh ``cluster.KMeans()`` is created for
        this instance, so separate analyzers never share a fitted estimator.
    name : str, optional
        Name forwarded to ``Analyzer.__init__``, by default ``'state'``.
    **kwargs
        Accepted for signature compatibility; currently unused.
    """
    # TODO: Wrapper for sklearn clustering??

    def __init__(self, estimator=None, name='state', **kwargs):
        # A default built in the signature would be a single object shared
        # (and re-fitted) by every Clustering instance.
        if estimator is None:
            estimator = cluster.KMeans()

        self.estimator = estimator
        self._est_params = self._get_estimator_params(estimator)

        Analyzer.__init__(self, name=name)

    @staticmethod
    def _unwrap_estimator(estimator):
        """Return the object that is actually fitted/inspected.

        For a ``Pipeline`` this is its *first* step, otherwise the estimator
        itself.
        """
        # NOTE(review): clustering steps usually sit at the end of a
        # pipeline; the first step is kept here to preserve the existing
        # behaviour - confirm this is intended.
        if isinstance(estimator, Pipeline):
            _, estimator = estimator.steps[0]
        return estimator

    def fit(self, ds, prepro=VarianceSubsampler(), **kwargs):
        """Fit the clustering estimator and label the full dataset.

        Parameters
        ----------
        ds : dataset
            Object exposing ``samples`` (the data that gets labelled) and
            ``sa.targets``; it must also be accepted by ``prepro.transform``
            and ``get_ds_data``.
        prepro : object with a ``transform(ds)`` method, optional
            Sub-sampler applied before fitting, by default
            ``VarianceSubsampler()``.
        **kwargs
            Currently unused.

        Returns
        -------
        self : Clustering
            The fitted analyzer.

        Attributes
        ----------
        scores : dict
            The results of the state identification:

            'labels': array
                cluster assigned to each sample of the sub-sampled set
            'states': array
                centroids of the clusters found on the sub-sampled set
            'dynamics': array
                label of every sample of the full dataset
            'X': array
                the sub-sampled data used for fitting
            'targets': array
                ``ds.sa.targets``
            'state_similarity': array
                dissimilarity between every centroid and every sample of
                the full dataset (see :meth:`_predict`)
        """
        ds_ = prepro.transform(ds)
        logger.info("Dataset shape %s", str(ds_.shape))

        X, _ = get_ds_data(ds_)

        logger.debug(isinstance(self.estimator, Pipeline))
        logger.debug(self.estimator)

        estimator = self._unwrap_estimator(self.estimator).fit(X)

        self.scores = dict()

        if hasattr(estimator, 'labels_'):
            self.scores['labels'] = estimator.labels_
        elif hasattr(estimator, 'predict'):  # Gaussian Mixture
            self.scores['labels'] = estimator.predict(X)

        if hasattr(estimator, 'cluster_centers_'):
            self.scores['states'] = estimator.cluster_centers_
        elif hasattr(estimator, 'means_'):  # Gaussian Mixture
            self.scores['states'] = estimator.means_
        elif 'labels' in self.scores:  # DBSCAN and other centroid-less models
            self.scores['states'] = self.get_centroids(X,
                                                       self.scores['labels'])

        # Always run the centroid-based prediction once: besides returning
        # labels it stores 'state_similarity' in self.scores.
        nearest_state = self._predict(ds.samples, self.scores['states'])

        if hasattr(estimator, 'predict'):
            self.scores['dynamics'] = estimator.predict(ds.samples)
        else:  # DBSCAN
            self.scores['dynamics'] = nearest_state

        self.scores['X'] = X
        self.scores['targets'] = ds.sa.targets

        self._info = self._store_info(ds)

        return self

    def get_centroids(self, X, labels):
        """Return the centroid of each cluster.

        Parameters
        ----------
        X : n_samples x n_features array
            The data that was clustered.
        labels : n_samples array
            The cluster label of each sample.

        Returns
        -------
        centroids : n_clusters x n_features array
            Mean of the samples of each cluster, ordered by sorted unique
            label.
        """
        labels = np.asarray(labels)
        return np.array([X[labels == label].mean(0)
                         for label in np.unique(labels)])

    def _predict(self, X, centroids, measure=euclidean, **kwargs):
        """Assign every sample to its least dissimilar centroid.

        The full centroid-by-sample dissimilarity matrix is stored in
        ``self.scores['state_similarity']`` as a side effect.

        Parameters
        ----------
        X : n_samples x n_features array
            The samples to label.
        centroids : n_clusters x n_features array
            The cluster centroids.
        measure : callable, optional
            Dissimilarity function, see ``scipy.spatial.distance``;
            by default ``euclidean``.
        **kwargs
            Forwarded to ``sekupy.utils.math.similiarity``.

        Returns
        -------
        labels : n_samples array
            Index of the most similar centroid for each sample.
        """
        from sekupy.utils.math import similiarity

        # Rows are centroids, columns are samples.
        dist = similiarity(centroids, X, measure=measure, **kwargs)
        self.scores['state_similarity'] = dist

        # argmin yields exactly one label per sample even when two centroids
        # are equidistant; matching the minimum with np.nonzero would return
        # one entry per tie and therefore an array of the wrong length.
        return np.argmin(dist, axis=0)

    def get_state_frequencies(self, dynamics=None, sfreq=128,
                              method=freq.psd_array_welch, **kwargs):
        """Return the spectrum of the state time courses.

        Parameters
        ----------
        dynamics : n_states x n_timepoints array, optional
            State time courses.  Defaults to
            ``self.scores['state_similarity']`` computed by :meth:`fit`.
        sfreq : float, optional
            Sampling frequency handed to ``method``, by default 128.
        method : callable, optional
            Function from ``mne.time_frequency`` called as
            ``method(dynamics, sfreq, **kwargs)``; by default
            ``psd_array_welch``.
        **kwargs
            Forwarded to ``method``.

        Returns
        -------
        results
            Whatever ``method`` returns; for ``psd_array_welch`` a tuple of
            the n_states x n_frequencies spectrum and the frequencies.
        """
        if dynamics is None:
            dynamics = self.scores['state_similarity']

        return method(dynamics, sfreq, **kwargs)

    def get_transition_matrix(self, dynamics=None):
        """Count the transitions between consecutive states.

        Parameters
        ----------
        dynamics : n_timepoints array of int, optional
            Sequence of state labels.  Defaults to
            ``self.scores['dynamics']`` computed by :meth:`fit`.

        Returns
        -------
        transitions : n_states x n_states array
            ``transitions[x, y]`` is the number of times state ``x`` is
            immediately followed by state ``y``; ``n_states`` is the largest
            label plus one.  An empty sequence gives a 0 x 0 array.
        """
        if dynamics is None:
            # The label sequence, not the float 'state_similarity' matrix,
            # is what can be used as matrix indices.
            dynamics = self.scores['dynamics']

        dynamics = np.asarray(dynamics)
        if dynamics.size == 0:
            return np.zeros((0, 0))

        n_states = int(dynamics.max()) + 1
        transitions = np.zeros((n_states, n_states))

        # Unbuffered add so that repeated (from, to) pairs are all counted.
        np.add.at(transitions, (dynamics[:-1], dynamics[1:]), 1)

        # transitions /= len(self._dynamics) - 1
        return transitions

    def get_state_duration(self, dynamics=None):
        """Return the length of every run of identical consecutive states.

        Parameters
        ----------
        dynamics : n_timepoints array, optional
            Sequence of state labels.  Defaults to
            ``self.scores['dynamics']`` computed by :meth:`fit`.

        Returns
        -------
        state_duration : n_runs x 2 array
            One row per run, in order of occurrence: first column is the
            state, second column the run length in time points.  The final
            run is included.  An empty sequence gives an empty array.
        """
        if dynamics is None:
            dynamics = self.scores['dynamics']

        dynamics = np.asarray(dynamics)
        if dynamics.size == 0:
            return np.array([])

        # Positions where a new run begins (position 0 always starts one).
        changes = np.flatnonzero(dynamics[1:] != dynamics[:-1]) + 1
        starts = np.concatenate(([0], changes))
        # Appending the sequence length closes the last run as well.
        ends = np.concatenate((changes, [dynamics.size]))
        lengths = (ends - starts).astype(float)

        return np.column_stack((dynamics[starts], lengths))

    def save(self, path=None, **kwargs):
        """Save the analysis and write ``scores`` to a ``.mat`` file.

        Parameters
        ----------
        path : str, optional
            Forwarded to ``Analyzer.save``.
        **kwargs
            Extra parameters forwarded to ``Analyzer.save``; the estimator
            parameters collected at construction override equal keys.

        Returns
        -------
        path : str
            The directory returned by ``Analyzer.save``, in which
            ``results-<prefix>.mat`` has been written.
        """
        from scipy.io import savemat

        params = dict(kwargs)
        params.update(self._est_params)

        path, prefix = Analyzer.save(self, path=path, **params)

        savemat(os.path.join(path, "results-%s.mat" % (prefix)),
                {'data': self.scores})

        return path

    def _get_estimator_params(self, estimator):
        """Return a copy of the estimator attributes plus its class name.

        Parameters
        ----------
        estimator : estimator object or Pipeline
            For a ``Pipeline`` the first step is inspected.

        Returns
        -------
        params : dict
            Copy of the estimator ``__dict__`` with the extra key
            ``'algorithm'`` holding the estimator name taken from its repr.
        """
        estimator = self._unwrap_estimator(estimator)

        params = estimator.__dict__.copy()
        # NOTE: this overwrites any estimator attribute that is itself
        # called 'algorithm'.
        params['algorithm'] = str(estimator).split('(')[0]

        return params

    def _check_fields(self):
        """No field validation is performed by this analyzer."""
        pass
def get_extrema_histogram(arg_extrema, n_timepoints):
    """Count, for every time point, how many subjects have an extremum there.

    Parameters
    ----------
    arg_extrema : pair of equally long integer sequences
        ``arg_extrema[0]`` holds the subject index and ``arg_extrema[1]``
        the time-point index of each extremum (presumably the index tuple
        returned by an arg-extrema search on a subjects x timepoints array).
    n_timepoints : int
        Length of the returned histogram.

    Returns
    -------
    hist_arg : n_timepoints array of float
        Number of subjects with an extremum at each time point.  A time
        point listed more than once for the same subject counts once.
    """
    subject_idx = np.asarray(arg_extrema[0])
    time_idx = np.asarray(arg_extrema[1])

    hist_arg = np.zeros(n_timepoints)

    # Walk the subject ids that actually occur: ids need not be contiguous
    # (a subject without extrema leaves a gap), so counting them and looping
    # over range(count) would skip the subjects after the gap.
    for subject in np.unique(subject_idx):
        hist_arg[time_idx[subject_idx == subject]] += 1

    return hist_arg