import numpy as np
from sekupy.utils.dataset import get_ds_data
from sekupy.analysis.utils import get_rois, get_params
from sekupy.preprocessing import FeatureSlicer
from sekupy.analysis.base import Analyzer
from sekupy.preprocessing.base import Transformer
from scipy.io.matlab.mio import savemat
from scipy.spatial.distance import pdist
import logging
logger = logging.getLogger(__name__)
[docs]
class RSA(Analyzer):
    """Representational similarity analysis (RSA) with an arbitrary
    dissimilarity measure.

    For every selected ROI, :meth:`fit` slices the dataset features,
    applies the preprocessing step and computes the pairwise distances
    between the samples with :func:`scipy.spatial.distance.pdist`.

    Parameters
    ----------
    n_jobs : int, optional. Default is 1.
        The number of CPUs to use to do the computation. The value is only
        stored on the instance; no method defined in this class reads it.
    permutation : int, optional. Default is 0.
        The number of permutations to be performed. The value is only
        stored on the instance; no method defined in this class reads it.
    verbose : int, optional. Default is 1.
        Verbosity level. The value is only stored on the instance.
    name : str, optional. Default is 'rsa'.
        Name of the analysis, forwarded to :class:`Analyzer`.
    **kwargs
        Extra keyword arguments forwarded unchanged to ``Analyzer.__init__``.

    Attributes
    ----------
    scores : dict
        Set by :meth:`fit`. Keys have the form
        ``"mask-<roi>_value-<v1>+<v2>..."``; each value is the array
        returned by ``pdist`` for that ROI.
    distance : str or callable
        Set by :meth:`fit`. The ``metric`` used for the last fit.
    conditions : array-like
        Set by :meth:`fit`. A copy of ``ds.targets`` of the fitted dataset.
    """

    def __init__(self,
                 n_jobs=1,
                 permutation=0,
                 verbose=1,
                 name='rsa',
                 **kwargs):
        self.n_jobs = n_jobs
        self.permutation = permutation
        self.verbose = verbose

        super().__init__(name=name, **kwargs)

    def fit(self, ds,
            roi='all',
            roi_values=None,
            metric='euclidean',
            prepro=None,
            **kwargs):
        """Compute the pairwise sample dissimilarities for each ROI.

        Parameters
        ----------
        ds : dataset
            The dataset to analyze. It is passed to
            ``FeatureSlicer.transform`` and must expose ``targets``; the
            preprocessed dataset must expose ``samples`` and ``shape``.
        roi : str or list, optional
            ROI selection handed to ``get_rois`` when ``roi_values`` is
            None (the default is 'all').
        roi_values : list, optional
            A list of ``(roi_name, values)`` tuples, where ``values`` is an
            iterable of the sub-roi values to keep. When None (the default)
            the list is obtained from ``get_rois(ds, roi)``.
        metric : str or callable, optional
            The metric passed to :func:`scipy.spatial.distance.pdist`
            (the default is 'euclidean').
        prepro : object with a ``transform(ds)`` method, optional
            Preprocessing applied to each ROI-sliced dataset. When None
            (the default) a fresh ``Transformer()`` is created for this
            call, so no instance is shared between calls.
        **kwargs
            Accepted for interface compatibility; not used by this method.

        Returns
        -------
        self : RSA
            The fitted analyzer, with ``scores``, ``distance`` and
            ``conditions`` set.
        """
        # A default built in the signature would be created once and shared
        # by every call; build it lazily instead.
        if prepro is None:
            prepro = Transformer()

        if roi_values is None:
            roi_values = get_rois(ds, roi)

        scores = {}

        # TODO: How to use multiple ROIs
        for roi_name, values in roi_values:
            ds_ = FeatureSlicer(**{roi_name: values}).transform(ds)
            ds_ = prepro.transform(ds_)
            logger.info("Dataset shape %s", ds_.shape)

            # pdist returns the condensed (upper-triangular) distance vector
            # between all pairs of rows of the sample matrix.
            distance = pdist(ds_.samples, metric=metric)

            value_label = "+".join(str(v) for v in values)
            scores["mask-%s_value-%s" % (roi_name, value_label)] = distance

        self._info = self._store_info(ds,
                                      distance=metric,
                                      roi=roi,
                                      prepro=prepro)

        self.scores = scores
        self.distance = metric
        # NOTE(review): the conditions come from the input ``ds`` and not
        # from the preprocessed dataset; if ``prepro`` changes the set of
        # samples these labels may not line up with the distances -- confirm.
        self.conditions = ds.targets.copy()

        return self

    # Only in subclasses
    def _get_analysis_info(self):
        """Return the base analysis info extended with the fitted ROI."""
        info = super()._get_analysis_info()
        info['roi'] = self._info['roi']
        return info

    def save(self, path=None, **kwargs):
        """Save the results, one MATLAB file per ROI.

        Each file stores the fitted ``conditions`` under the key
        ``'conditions'`` and the ROI distances under ``'test_score'``.

        Parameters
        ----------
        path : str, optional
            path where to store files (the default is
            set up by :class:`sekupy.analysis.Analyzer`)
        **kwargs
            Forwarded to ``Analyzer.save`` and, together with the returned
            prefix and the ROI key, to ``_get_filename``.
        """
        import os

        # NOTE(review): ``savemat`` is imported at module level from
        # ``scipy.io.matlab.mio``, a private path deprecated since SciPy 1.8;
        # ``from scipy.io import savemat`` is the public location.
        path, prefix = super().save(path=path, **kwargs)
        kwargs['prefix'] = prefix

        for roi, scores in self.scores.items():
            kwargs['mask'] = roi
            filename = self._get_filename(**kwargs)
            logger.info("Saving %s", filename)

            mat_score = {
                'conditions': self.conditions,
                'test_score': scores,
            }
            savemat(os.path.join(path, filename), mat_score)

        return

    def _get_filename(self, **kwargs):
        """Build the output filename for one ROI.

        The name is ``"<prefix>_<key>-<value>_..._<mask>_data.mat"``: the
        ``key-value`` pairs are the distance metric followed by whatever
        ``_get_prepro_info`` returns, with underscores inside the values
        replaced by ``+``. Empty components (e.g. an empty or missing
        prefix) are skipped so the name never starts with, or contains, a
        dangling underscore separator.

        Parameters
        ----------
        prefix : str, optional
            Leading part of the filename.
        mask : str
            ROI key, as stored in ``self.scores``.
        **kwargs
            All keyword arguments (including ``prefix`` and ``mask``) are
            forwarded to ``_get_prepro_info``.

        Returns
        -------
        filename : str
        """
        logger.debug(kwargs)

        params = {'distance': self.distance}
        params.update(self._get_prepro_info(**kwargs))
        logger.debug(params)

        prefix = kwargs.get('prefix')
        trailing = kwargs['mask']

        midpart = "_".join("%s-%s" % (k, str(v).replace("_", "+"))
                           for k, v in params.items())

        # Drop empty pieces so an empty prefix does not leave a leading "_".
        parts = [part for part in (prefix, midpart, trailing) if part]
        return "%s_data.mat" % "_".join(parts)