Source code for sekupy.analysis.pipeline

import logging
import os
from sekupy.analysis.base import Analyzer
from sekupy.utils.time import get_time
from sekupy.utils.files import make_dir
from collections import Counter

logger = logging.getLogger(__name__)


class AnalysisPipeline(Analyzer):
    """Configurator-driven analysis: fetch a dataset, transform it, fit an estimator.

    The loader, the transformer (a sequence of nodes) and the estimator are
    all obtained from ``configurator.fit()``.
    """

    def __init__(self, configurator, name="base"):
        """This class is used to perform a general analysis based on the
        configuration that is specified by the parameter.
        (see ```sekupy.analysis.configurator.Configurator``` docs)

        Parameters
        ----------
        configurator : ```sekupy.analysis.configurator.Configurator```
            object used to specify the preprocessing and the analysis
            to be performed
        name : str, optional
            Name of the pipeline; it is stored in the saved parameters
            (as ``'pipeline'``) and used in the saved dataset filename
            (the default is "base").
        """
        self._configurator = configurator
        self.name = name

    def fit(self, ds=None, **kwargs):
        """Fit the analysis on the dataset.

        Parameters
        ----------
        ds : pymvpa dataset, optional
            The dataset is the input to the analysis. If None, it is fetched
            with the loader returned by the configurator.
        kwargs : dict
            Optional parameters forwarded to the estimator's ``fit``.

        Returns
        -------
        self : AnalysisPipeline

        Raises
        ------
        ValueError
            If ``ds`` is None and the configurator provides no loader.
        """
        objects = self._configurator.fit()
        self._loader = objects['loader']
        self._transformer = objects['transformer']
        self._estimator = objects['estimator']
        logger.debug(self._estimator)

        if ds is None:
            if self._loader is None:
                raise ValueError(
                    "You must specify a dataset or a loader in the Configurator!")
            ds = self._fetch()

        self._ds = ds
        ds_ = self._transform(ds)
        self._estimator.fit(ds_, **kwargs)
        return self

    def _fetch(self):
        """Fetch a dataset with ``self._loader`` using the configurator's fetch kwargs."""
        fetch_kw = self._configurator._get_function_kwargs(function="fetch")
        logger.info(fetch_kw)
        return self._loader.fetch(**fetch_kw)

    def _get_ds(self):
        """Return the dataset used by ``fit``, fetching a fresh one if ``fit`` was not called.

        Raises
        ------
        ValueError
            If no dataset is stored and the configurator provides no loader.
        """
        if hasattr(self, '_ds'):
            return self._ds

        objects = self._configurator.fit()
        self._loader = objects['loader']
        if self._loader is None:
            raise ValueError(
                "You must specify a dataset or a loader in the Configurator!")
        return self._fetch()

    def _transform(self, ds):
        """Apply every transformer node to ``ds`` and record target counts.

        Side effects on the configurator's default options:
        ``'ds__target_count_pre'`` holds the target counts before any node,
        every ``ds.a`` attribute is stored under ``"ds.a.<name>"``, and the
        counts after the ``'balancer'`` / ``'target_transformer'`` nodes are
        stored under ``'ds__target_count_<node name>'``.
        """
        options = self._configurator._default_options
        options['ds__target_count_pre'] = Counter(ds.targets)

        # TODO: Is it useful??
        ds_dict = {"ds.a.%s" % (k): v.value for k, v in ds.a.items()}
        options.update(ds_dict)

        for node in self._transformer.nodes:
            ds = node.transform(ds)
            if node.name in ['balancer', 'target_transformer']:
                key = 'ds__target_count_%s' % (node.name)
                options[key] = Counter(ds.targets)

        return ds

    def save(self, path=None, subdir="0_results", save_ds=False, **kwargs):
        """Save the estimator results and, optionally, the input dataset.

        Parameters
        ----------
        path : str, optional
            Destination passed to the estimator's ``save``. When None, the
            configurator's ``'path'`` option (if present) is used instead.
        subdir : str, optional
            Unused; kept for backward compatibility.
        save_ds : bool, optional
            If True, the dataset used in ``fit`` is also saved (gzip) in the
            path returned by the estimator's ``save``.
        kwargs : dict
            Extra parameters forwarded to the estimator's ``save``; they
            override configurator options with the same key.
        """
        # TODO: Mantain subdir for compatibility purposes?
        # Work on a copy: saving must not mutate the configurator's options
        # (previously 'path' was popped out of them, so later saves lost it).
        params = dict(self._configurator._default_options)
        logger.debug(params)

        # 'path' is passed explicitly to estimator.save, so it must never stay
        # in params (duplicate keyword). An explicit argument wins over config.
        config_path = params.pop("path", None)
        if path is None:
            path = config_path

        params['pipeline'] = self.name
        params.update(kwargs)

        # Save results
        path = self._estimator.save(path=path, **params)
        if save_ds:
            self._save_ds(path=path, params=params)

    def _save_ds(self, path, params=None):
        """Save the fitted dataset as ``ds_pipeline-<name>_id-<id>_num-<num>.gzip`` in ``path``.

        Parameters
        ----------
        path : str
            Directory in which the dataset file is written.
        params : dict, optional
            Mapping providing ``'id'`` and ``'num'``; defaults to the
            configurator's default options.
        """
        if params is None:
            params = self._configurator._default_options
        fname = "ds_pipeline-%s_id-%s_num-%s.gzip" % (
            self.name, params['id'], params['num'])
        self._ds.save(os.path.join(path, fname), compression='gzip')