You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

565 lines
19 KiB

4 years ago
  1. """Base classes for all estimators."""
  2. # Author: Gael Varoquaux <gael.varoquaux@normalesup.org>
  3. # License: BSD 3 clause
  4. import copy
  5. import warnings
  6. from collections import defaultdict
  7. import numpy as np
  8. from scipy import sparse
  9. from .externals import six
  10. from .utils.fixes import signature
  11. from . import __version__
  12. ##############################################################################
  13. def _first_and_last_element(arr):
  14. """Returns first and last element of numpy array or sparse matrix."""
  15. if isinstance(arr, np.ndarray) or hasattr(arr, 'data'):
  16. # numpy array or sparse matrix with .data attribute
  17. data = arr.data if sparse.issparse(arr) else arr
  18. return data.flat[0], data.flat[-1]
  19. else:
  20. # Sparse matrices without .data attribute. Only dok_matrix at
  21. # the time of writing, in this case indexing is fast
  22. return arr[0, 0], arr[-1, -1]
  23. def clone(estimator, safe=True):
  24. """Constructs a new estimator with the same parameters.
  25. Clone does a deep copy of the model in an estimator
  26. without actually copying attached data. It yields a new estimator
  27. with the same parameters that has not been fit on any data.
  28. Parameters
  29. ----------
  30. estimator : estimator object, or list, tuple or set of objects
  31. The estimator or group of estimators to be cloned
  32. safe : boolean, optional
  33. If safe is false, clone will fall back to a deep copy on objects
  34. that are not estimators.
  35. """
  36. estimator_type = type(estimator)
  37. # XXX: not handling dictionaries
  38. if estimator_type in (list, tuple, set, frozenset):
  39. return estimator_type([clone(e, safe=safe) for e in estimator])
  40. elif not hasattr(estimator, 'get_params'):
  41. if not safe:
  42. return copy.deepcopy(estimator)
  43. else:
  44. raise TypeError("Cannot clone object '%s' (type %s): "
  45. "it does not seem to be a scikit-learn estimator "
  46. "as it does not implement a 'get_params' methods."
  47. % (repr(estimator), type(estimator)))
  48. klass = estimator.__class__
  49. new_object_params = estimator.get_params(deep=False)
  50. for name, param in six.iteritems(new_object_params):
  51. new_object_params[name] = clone(param, safe=False)
  52. new_object = klass(**new_object_params)
  53. params_set = new_object.get_params(deep=False)
  54. # quick sanity check of the parameters of the clone
  55. for name in new_object_params:
  56. param1 = new_object_params[name]
  57. param2 = params_set[name]
  58. if param1 is not param2:
  59. raise RuntimeError('Cannot clone object %s, as the constructor '
  60. 'either does not set or modifies parameter %s' %
  61. (estimator, name))
  62. return new_object
  63. ###############################################################################
  64. def _pprint(params, offset=0, printer=repr):
  65. """Pretty print the dictionary 'params'
  66. Parameters
  67. ----------
  68. params : dict
  69. The dictionary to pretty print
  70. offset : int
  71. The offset in characters to add at the begin of each line.
  72. printer : callable
  73. The function to convert entries to strings, typically
  74. the builtin str or repr
  75. """
  76. # Do a multi-line justified repr:
  77. options = np.get_printoptions()
  78. np.set_printoptions(precision=5, threshold=64, edgeitems=2)
  79. params_list = list()
  80. this_line_length = offset
  81. line_sep = ',\n' + (1 + offset // 2) * ' '
  82. for i, (k, v) in enumerate(sorted(six.iteritems(params))):
  83. if type(v) is float:
  84. # use str for representing floating point numbers
  85. # this way we get consistent representation across
  86. # architectures and versions.
  87. this_repr = '%s=%s' % (k, str(v))
  88. else:
  89. # use repr of the rest
  90. this_repr = '%s=%s' % (k, printer(v))
  91. if len(this_repr) > 500:
  92. this_repr = this_repr[:300] + '...' + this_repr[-100:]
  93. if i > 0:
  94. if (this_line_length + len(this_repr) >= 75 or '\n' in this_repr):
  95. params_list.append(line_sep)
  96. this_line_length = len(line_sep)
  97. else:
  98. params_list.append(', ')
  99. this_line_length += 2
  100. params_list.append(this_repr)
  101. this_line_length += len(this_repr)
  102. np.set_printoptions(**options)
  103. lines = ''.join(params_list)
  104. # Strip trailing space to avoid nightmare in doctests
  105. lines = '\n'.join(l.rstrip(' ') for l in lines.split('\n'))
  106. return lines
  107. ###############################################################################
class BaseEstimator(object):
    """Base class for all estimators in scikit-learn

    Notes
    -----
    All estimators should specify all the parameters that can be set
    at the class level in their ``__init__`` as explicit keyword
    arguments (no ``*args`` or ``**kwargs``).
    """

    @classmethod
    def _get_param_names(cls):
        """Get parameter names for the estimator.

        Introspects the signature of ``cls.__init__`` and returns the
        sorted list of its explicit keyword parameter names.

        Raises
        ------
        RuntimeError
            If the constructor uses varargs (``*args``).
        """
        # fetch the constructor or the original constructor before
        # deprecation wrapping if any
        init = getattr(cls.__init__, 'deprecated_original', cls.__init__)
        if init is object.__init__:
            # No explicit constructor to introspect
            return []

        # introspect the constructor arguments to find the model parameters
        # to represent
        init_signature = signature(init)
        # Consider the constructor parameters excluding 'self';
        # VAR_KEYWORD (**kwargs) is tolerated and simply ignored.
        parameters = [p for p in init_signature.parameters.values()
                      if p.name != 'self' and p.kind != p.VAR_KEYWORD]
        for p in parameters:
            if p.kind == p.VAR_POSITIONAL:
                raise RuntimeError("scikit-learn estimators should always "
                                   "specify their parameters in the signature"
                                   " of their __init__ (no varargs)."
                                   " %s with constructor %s doesn't "
                                   " follow this convention."
                                   % (cls, init_signature))
        # Extract and sort argument names excluding 'self'
        return sorted([p.name for p in parameters])

    def get_params(self, deep=True):
        """Get parameters for this estimator.

        Parameters
        ----------
        deep : boolean, optional
            If True, will return the parameters for this estimator and
            contained subobjects that are estimators.

        Returns
        -------
        params : mapping of string to any
            Parameter names mapped to their values.
        """
        out = dict()
        for key in self._get_param_names():
            # Parameters not (yet) set on the instance default to None.
            value = getattr(self, key, None)
            if deep and hasattr(value, 'get_params'):
                # Flatten nested estimator params as '<name>__<subname>'.
                deep_items = value.get_params().items()
                out.update((key + '__' + k, val) for k, val in deep_items)
            out[key] = value
        return out

    def set_params(self, **params):
        """Set the parameters of this estimator.

        The method works on simple estimators as well as on nested objects
        (such as pipelines). The latter have parameters of the form
        ``<component>__<parameter>`` so that it's possible to update each
        component of a nested object.

        Returns
        -------
        self

        Raises
        ------
        ValueError
            If a key (or a key's ``<component>`` prefix) is not a valid
            parameter of this estimator.
        """
        if not params:
            # Simple optimization to gain speed (inspect is slow)
            return self
        valid_params = self.get_params(deep=True)

        nested_params = defaultdict(dict)  # grouped by prefix
        for key, value in params.items():
            # Split 'component__parameter' into prefix and sub-key;
            # delim is '' when there is no '__' in the key.
            key, delim, sub_key = key.partition('__')
            if key not in valid_params:
                raise ValueError('Invalid parameter %s for estimator %s. '
                                 'Check the list of available parameters '
                                 'with `estimator.get_params().keys()`.' %
                                 (key, self))
            if delim:
                nested_params[key][sub_key] = value
            else:
                setattr(self, key, value)
                valid_params[key] = value

        # Delegate grouped sub-parameters to each nested estimator in one
        # call so their own set_params validation applies.
        for key, sub_params in nested_params.items():
            valid_params[key].set_params(**sub_params)

        return self

    def __repr__(self):
        # Render as ClassName(param=value, ...), wrapped by _pprint with
        # the class-name length as indentation offset.
        class_name = self.__class__.__name__
        return '%s(%s)' % (class_name, _pprint(self.get_params(deep=False),
                                               offset=len(class_name),),)

    def __getstate__(self):
        try:
            state = super(BaseEstimator, self).__getstate__()
        except AttributeError:
            # No parent __getstate__: fall back to a shallow __dict__ copy.
            state = self.__dict__.copy()

        if type(self).__module__.startswith('sklearn.'):
            # Stamp the pickle with the sklearn version so __setstate__
            # can warn on cross-version unpickling.
            return dict(state.items(), _sklearn_version=__version__)
        else:
            return state

    def __setstate__(self, state):
        if type(self).__module__.startswith('sklearn.'):
            # Pickles created before 0.18 carry no version stamp.
            pickle_version = state.pop("_sklearn_version", "pre-0.18")
            if pickle_version != __version__:
                warnings.warn(
                    "Trying to unpickle estimator {0} from version {1} when "
                    "using version {2}. This might lead to breaking code or "
                    "invalid results. Use at your own risk.".format(
                        self.__class__.__name__, pickle_version, __version__),
                    UserWarning)
        try:
            super(BaseEstimator, self).__setstate__(state)
        except AttributeError:
            # No parent __setstate__: restore attributes directly.
            self.__dict__.update(state)
  218. ###############################################################################
  219. class ClassifierMixin(object):
  220. """Mixin class for all classifiers in scikit-learn."""
  221. _estimator_type = "classifier"
  222. def score(self, X, y, sample_weight=None):
  223. """Returns the mean accuracy on the given test data and labels.
  224. In multi-label classification, this is the subset accuracy
  225. which is a harsh metric since you require for each sample that
  226. each label set be correctly predicted.
  227. Parameters
  228. ----------
  229. X : array-like, shape = (n_samples, n_features)
  230. Test samples.
  231. y : array-like, shape = (n_samples) or (n_samples, n_outputs)
  232. True labels for X.
  233. sample_weight : array-like, shape = [n_samples], optional
  234. Sample weights.
  235. Returns
  236. -------
  237. score : float
  238. Mean accuracy of self.predict(X) wrt. y.
  239. """
  240. from .metrics import accuracy_score
  241. return accuracy_score(y, self.predict(X), sample_weight=sample_weight)
  242. ###############################################################################
  243. class RegressorMixin(object):
  244. """Mixin class for all regression estimators in scikit-learn."""
  245. _estimator_type = "regressor"
  246. def score(self, X, y, sample_weight=None):
  247. """Returns the coefficient of determination R^2 of the prediction.
  248. The coefficient R^2 is defined as (1 - u/v), where u is the residual
  249. sum of squares ((y_true - y_pred) ** 2).sum() and v is the total
  250. sum of squares ((y_true - y_true.mean()) ** 2).sum().
  251. The best possible score is 1.0 and it can be negative (because the
  252. model can be arbitrarily worse). A constant model that always
  253. predicts the expected value of y, disregarding the input features,
  254. would get a R^2 score of 0.0.
  255. Parameters
  256. ----------
  257. X : array-like, shape = (n_samples, n_features)
  258. Test samples. For some estimators this may be a
  259. precomputed kernel matrix instead, shape = (n_samples,
  260. n_samples_fitted], where n_samples_fitted is the number of
  261. samples used in the fitting for the estimator.
  262. y : array-like, shape = (n_samples) or (n_samples, n_outputs)
  263. True values for X.
  264. sample_weight : array-like, shape = [n_samples], optional
  265. Sample weights.
  266. Returns
  267. -------
  268. score : float
  269. R^2 of self.predict(X) wrt. y.
  270. """
  271. from .metrics import r2_score
  272. return r2_score(y, self.predict(X), sample_weight=sample_weight,
  273. multioutput='variance_weighted')
  274. ###############################################################################
  275. class ClusterMixin(object):
  276. """Mixin class for all cluster estimators in scikit-learn."""
  277. _estimator_type = "clusterer"
  278. def fit_predict(self, X, y=None):
  279. """Performs clustering on X and returns cluster labels.
  280. Parameters
  281. ----------
  282. X : ndarray, shape (n_samples, n_features)
  283. Input data.
  284. y : Ignored
  285. not used, present for API consistency by convention.
  286. Returns
  287. -------
  288. labels : ndarray, shape (n_samples,)
  289. cluster labels
  290. """
  291. # non-optimized default implementation; override when a better
  292. # method is possible for a given clustering algorithm
  293. self.fit(X)
  294. return self.labels_
  295. class BiclusterMixin(object):
  296. """Mixin class for all bicluster estimators in scikit-learn"""
  297. @property
  298. def biclusters_(self):
  299. """Convenient way to get row and column indicators together.
  300. Returns the ``rows_`` and ``columns_`` members.
  301. """
  302. return self.rows_, self.columns_
  303. def get_indices(self, i):
  304. """Row and column indices of the i'th bicluster.
  305. Only works if ``rows_`` and ``columns_`` attributes exist.
  306. Parameters
  307. ----------
  308. i : int
  309. The index of the cluster.
  310. Returns
  311. -------
  312. row_ind : np.array, dtype=np.intp
  313. Indices of rows in the dataset that belong to the bicluster.
  314. col_ind : np.array, dtype=np.intp
  315. Indices of columns in the dataset that belong to the bicluster.
  316. """
  317. rows = self.rows_[i]
  318. columns = self.columns_[i]
  319. return np.nonzero(rows)[0], np.nonzero(columns)[0]
  320. def get_shape(self, i):
  321. """Shape of the i'th bicluster.
  322. Parameters
  323. ----------
  324. i : int
  325. The index of the cluster.
  326. Returns
  327. -------
  328. shape : (int, int)
  329. Number of rows and columns (resp.) in the bicluster.
  330. """
  331. indices = self.get_indices(i)
  332. return tuple(len(i) for i in indices)
  333. def get_submatrix(self, i, data):
  334. """Returns the submatrix corresponding to bicluster `i`.
  335. Parameters
  336. ----------
  337. i : int
  338. The index of the cluster.
  339. data : array
  340. The data.
  341. Returns
  342. -------
  343. submatrix : array
  344. The submatrix corresponding to bicluster i.
  345. Notes
  346. -----
  347. Works with sparse matrices. Only works if ``rows_`` and
  348. ``columns_`` attributes exist.
  349. """
  350. from .utils.validation import check_array
  351. data = check_array(data, accept_sparse='csr')
  352. row_ind, col_ind = self.get_indices(i)
  353. return data[row_ind[:, np.newaxis], col_ind]
  354. ###############################################################################
  355. class TransformerMixin(object):
  356. """Mixin class for all transformers in scikit-learn."""
  357. def fit_transform(self, X, y=None, **fit_params):
  358. """Fit to data, then transform it.
  359. Fits transformer to X and y with optional parameters fit_params
  360. and returns a transformed version of X.
  361. Parameters
  362. ----------
  363. X : numpy array of shape [n_samples, n_features]
  364. Training set.
  365. y : numpy array of shape [n_samples]
  366. Target values.
  367. Returns
  368. -------
  369. X_new : numpy array of shape [n_samples, n_features_new]
  370. Transformed array.
  371. """
  372. # non-optimized default implementation; override when a better
  373. # method is possible for a given clustering algorithm
  374. if y is None:
  375. # fit method of arity 1 (unsupervised transformation)
  376. return self.fit(X, **fit_params).transform(X)
  377. else:
  378. # fit method of arity 2 (supervised transformation)
  379. return self.fit(X, y, **fit_params).transform(X)
  380. class DensityMixin(object):
  381. """Mixin class for all density estimators in scikit-learn."""
  382. _estimator_type = "DensityEstimator"
  383. def score(self, X, y=None):
  384. """Returns the score of the model on the data X
  385. Parameters
  386. ----------
  387. X : array-like, shape = (n_samples, n_features)
  388. Returns
  389. -------
  390. score : float
  391. """
  392. pass
  393. class OutlierMixin(object):
  394. """Mixin class for all outlier detection estimators in scikit-learn."""
  395. _estimator_type = "outlier_detector"
  396. def fit_predict(self, X, y=None):
  397. """Performs outlier detection on X.
  398. Returns -1 for outliers and 1 for inliers.
  399. Parameters
  400. ----------
  401. X : ndarray, shape (n_samples, n_features)
  402. Input data.
  403. y : Ignored
  404. not used, present for API consistency by convention.
  405. Returns
  406. -------
  407. y : ndarray, shape (n_samples,)
  408. 1 for inliers, -1 for outliers.
  409. """
  410. # override for transductive outlier detectors like LocalOulierFactor
  411. return self.fit(X).predict(X)
  412. ###############################################################################
class MetaEstimatorMixin(object):
    """Mixin class for all meta estimators in scikit-learn."""
    # this is just a tag for the moment: the class carries no behavior and
    # only marks estimators that wrap other estimators.
  416. ###############################################################################
  417. def is_classifier(estimator):
  418. """Returns True if the given estimator is (probably) a classifier.
  419. Parameters
  420. ----------
  421. estimator : object
  422. Estimator object to test.
  423. Returns
  424. -------
  425. out : bool
  426. True if estimator is a classifier and False otherwise.
  427. """
  428. return getattr(estimator, "_estimator_type", None) == "classifier"
  429. def is_regressor(estimator):
  430. """Returns True if the given estimator is (probably) a regressor.
  431. Parameters
  432. ----------
  433. estimator : object
  434. Estimator object to test.
  435. Returns
  436. -------
  437. out : bool
  438. True if estimator is a regressor and False otherwise.
  439. """
  440. return getattr(estimator, "_estimator_type", None) == "regressor"
  441. def is_outlier_detector(estimator):
  442. """Returns True if the given estimator is (probably) an outlier detector.
  443. Parameters
  444. ----------
  445. estimator : object
  446. Estimator object to test.
  447. Returns
  448. -------
  449. out : bool
  450. True if estimator is an outlier detector and False otherwise.
  451. """
  452. return getattr(estimator, "_estimator_type", None) == "outlier_detector"