Source code for sekupy.analysis.decoding.temporal_decoding

import numpy as np

from sklearn.metrics._scorer import check_scoring
from sklearn.svm import SVC
from sklearn.preprocessing import LabelEncoder
from sklearn.pipeline import Pipeline
from sklearn.model_selection._split import LeaveOneGroupOut

from sekupy.utils.dataset import get_ds_data
from sekupy.utils.dataset import temporal_attribute_reshaping, \
    temporal_transformation

from sekupy.preprocessing import FeatureSlicer
from sekupy.analysis.decoding.roi_decoding import RoiDecoding
from sekupy.preprocessing.base import Transformer

from scipy.io.matlab.mio import savemat

from mne.decoding import GeneralizingEstimator
from imblearn.under_sampling import RandomUnderSampler
import logging
logger = logging.getLogger(__name__)


class TemporalDecoding(RoiDecoding):
    """Temporal generalization decoding with an arbitrary classifier.

    The supplied estimator is wrapped in a scikit-learn ``Pipeline`` (step
    name ``'clf'``) and then in ``mne.decoding.GeneralizingEstimator``.
    See King 2014, TICS.

    Parameters
    ----------
    estimator : estimator object implementing 'fit', Pipeline or None
        The object used to fit the data. If None, a linear ``SVC(C=1)`` is
        used. Anything that is not already a ``Pipeline`` is wrapped in a
        one-step pipeline whose step is named ``'clf'``.
    n_jobs : int, optional. Default is 1.
        The number of CPUs to use to do the computation
        (forwarded to ``RoiDecoding``).
    scoring : string or callable, optional. Default is 'accuracy'.
        The scoring strategy, forwarded to ``RoiDecoding``. Note that
        ``self.scoring`` is reset to None right after the parent
        initialisation.
    cv : cross-validation generator, optional.
        Default is ``LeaveOneGroupOut()``. Forwarded to ``RoiDecoding``.
    permutation : int. Default is 0.
        The number of permutations to be performed. If the number is 0,
        no permutation is performed.
    verbose : int, optional. Default is 1.
        Verbosity level.

    Attributes
    ----------
    scores : dict.
        The dictionary of results for each roi selected, as described by
        the original documentation of this class: the key is the union of
        the name of the roi and the value(s); the value is a list of
        values, the number is equal to the permutations.
        NOTE(review): populated by ``RoiDecoding`` -- confirm there.
    """

    def __init__(self,
                 estimator=None,
                 n_jobs=1,
                 scoring='accuracy',
                 cv=LeaveOneGroupOut(),
                 permutation=0,
                 verbose=1,
                 **kwargs):

        # The parent receives the estimator exactly as given (possibly
        # None); self.estimator is overwritten below.
        RoiDecoding.__init__(self,
                             estimator=estimator,
                             n_jobs=n_jobs,
                             scoring=scoring,
                             cv=cv,
                             permutation=permutation,
                             verbose=verbose,
                             name='temporal_decoding',
                             **kwargs)

        if estimator is None:
            estimator = Pipeline(steps=[('clf', SVC(C=1, kernel='linear'))])

        # _save_estimator looks steps up by name, so always work on a
        # Pipeline exposing a 'clf' step.
        if not isinstance(estimator, Pipeline):
            estimator = Pipeline(steps=[('clf', estimator)])

        self.estimator = GeneralizingEstimator(estimator)

        # This seems the only way to cope with this
        self.scoring = None

    def _get_data(self, ds, cv_attr, time_attr='frame', balancer=None,
                  **kwargs):
        """Extract ``X``, ``y`` and ``groups`` from the dataset.

        Parameters
        ----------
        ds : dataset
            Dataset read through ``get_ds_data`` and ``ds.sa``.
        cv_attr : str, list of str or None
            Sample attribute(s) used to build the cross-validation groups.
            With a list, one column per attribute is produced. With None,
            no groups are built.
        time_attr : str. Default is 'frame'.
            Sample attribute holding the time index of each sample.
        balancer : sampler with ``fit_resample`` / ``sample_indices_``
            Used only to pick which samples are kept. If None, a fresh
            ``RandomUnderSampler()`` is created for this call.

        Returns
        -------
        X, y, groups
            The selected data, targets and groups (``groups`` is None when
            ``cv_attr`` is None). When the data returned by
            ``get_ds_data`` is already 3-dimensional the result of
            ``RoiDecoding._get_data`` is returned instead.
        """
        import warnings
        warnings.warn(
            "This function must be replaced by super function _get_data",
            DeprecationWarning)

        X, y = get_ds_data(ds)
        if len(X.shape) == 3:
            return RoiDecoding._get_data(self, ds, cv_attr, **kwargs)

        # A new sampler per call: fit_resample stores fitted state on the
        # object, so a definition-time default would be shared by all calls.
        if balancer is None:
            balancer = RandomUnderSampler()

        t_values = ds.sa[time_attr].value
        X, y = temporal_transformation(X, y, t_values)

        # Only the selected indices matter; the resampled arrays are
        # discarded. The sampler is fitted on the first slice of the last
        # axis of X.
        balancer.fit_resample(X[:, :, 0], y)
        indices = np.sort(balancer.sample_indices_)

        groups = None
        if cv_attr is not None:
            if isinstance(cv_attr, list):
                groups = np.vstack([
                    temporal_attribute_reshaping(ds.sa[att].value, t_values)
                    for att in cv_attr
                ]).T
            else:
                groups = temporal_attribute_reshaping(ds.sa[cv_attr].value,
                                                      t_values)
            groups = groups[indices]

        X, y = X[indices], y[indices]

        logger.info(np.unique(y, return_counts=True))

        return X, y, groups

    def fit(self, ds,
            time_attr='frame',
            roi='all',
            roi_values=None,
            cv_attr=None,
            prepro=Transformer(),
            balancer=None,
            return_splits=True,
            return_predictions=False,
            **kwargs):
        """General method to fit data.

        All arguments are forwarded to ``RoiDecoding.fit``; ``time_attr``
        and ``balancer`` are the extra arguments that ``_get_data`` of this
        class accepts.

        Parameters
        ----------
        ds : dataset
            The dataset to be analysed.
        time_attr : str. Default is 'frame'.
            Sample attribute holding the time index of each sample.
        roi, roi_values, cv_attr, prepro, return_splits, return_predictions
            Forwarded unchanged to ``RoiDecoding.fit``.
        balancer : sampler or None
            If None, a fresh ``RandomUnderSampler()`` is created for this
            call instead of sharing one instance across calls.

        Returns
        -------
        Whatever ``RoiDecoding.fit`` returns.
        """
        if balancer is None:
            balancer = RandomUnderSampler()

        # Propagate the parent's return value instead of dropping it.
        return super().fit(ds,
                           cv_attr=cv_attr,
                           roi=roi,
                           roi_values=roi_values,
                           prepro=prepro,
                           return_predictions=return_predictions,
                           return_splits=return_splits,
                           time_attr=time_attr,
                           balancer=balancer,
                           **kwargs)

    def _save_score(self, score, save_estimator=False):
        """Convert a score dictionary into a flat dictionary of fields.

        Keys containing ``"test_"`` are copied as they are, ``'estimator'``
        is expanded through ``_save_estimator``, ``'splits'`` through
        ``_save_splits`` and ``'split_name'`` is reduced to the ``'test'``
        entry of each split. Any other key is ignored.
        """
        mat_file = dict()

        for key, value in score.items():
            if "test_" in key:
                mat_file[key] = value
            elif key == 'estimator':
                mat_file.update(self._save_estimator(value, save_estimator))
            elif key == "splits":
                mat_file.update(self._save_splits(value))
            elif key == "split_name":
                mat_file['split_name'] = [s['test'] for s in value]

        return mat_file

    def _save_estimator(self, estimators, save_estimator):
        """Collect classifier weights and selected features.

        Parameters
        ----------
        estimators : iterable
            One fitted estimator per fold, each exposing ``estimators_``
            (one fitted pipeline per timepoint).
        save_estimator : bool
            Currently unused; kept for interface compatibility.

        Returns
        -------
        dict
            ``'weights'``: ``coef_`` of the ``'clf'`` step, for the
            pipelines that expose it; ``'features'``: ``get_support()`` of
            the ``'fsel'`` step, for the pipelines that have one. Both are
            nested per fold, then per timepoint, and converted with
            ``np.array``.
        """
        weights = []
        features = []

        # For each fold
        for fold_estimator in estimators:
            timepoint_estimators = fold_estimator.estimators_

            # For each timepoint
            weights.append([
                est.named_steps['clf'].coef_
                for est in timepoint_estimators
                if hasattr(est.named_steps['clf'], 'coef_')
            ])
            features.append([
                est.named_steps['fsel'].get_support()
                for est in timepoint_estimators
                if 'fsel' in est.named_steps
            ])

        mat_ = dict()
        mat_['weights'] = np.array(weights)
        mat_['features'] = np.array(features)

        return mat_